From f641698f1cbe22db63da4afd0a4421dc4ce6a1c5 Mon Sep 17 00:00:00 2001 From: Entejar Alam Date: Wed, 17 Sep 2025 09:18:05 -0500 Subject: [PATCH 1/9] Updated partition_tracker to track auxiliary data for CLogLog Ordinal BART model --- include/stochtree/partition_tracker.h | 25 ++++++++++++++++ src/partition_tracker.cpp | 42 +++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/include/stochtree/partition_tracker.h b/include/stochtree/partition_tracker.h index 0790d87a..a2f5dd70 100644 --- a/include/stochtree/partition_tracker.h +++ b/include/stochtree/partition_tracker.h @@ -93,6 +93,17 @@ class ForestTracker { int GetNumFeatures() {return num_features_;} bool Initialized() {return initialized_;} + /*! + * \brief Ordinal auxiliary data management methods + * Methods to initialize, get, and set auxiliary data for cloglog ordinal bart models + * n_levels is the number of outcome levels for the ordinal response + * type_idx is the index of the type of auxiliary data (0: latent Z, 1: forest predictions, 2: cutpoints gamma, 3: cumsum exp of cutpoints) + */ + void InitializeOrdinalAuxData(data_size_t num_observations, int n_levels); + double GetOrdinalAuxData(int type_idx, data_size_t obs_idx) const; + void SetOrdinalAuxData(int type_idx, data_size_t obs_idx, double value); + std::vector& GetOrdinalAuxDataVector(int type_idx); + private: /*! \brief Mapper from observations to predicted values summed over every tree in a forest */ std::vector sum_predictions_; @@ -121,6 +132,20 @@ class ForestTracker { void UpdateSampleTrackersInternal(TreeEnsemble& forest, Eigen::MatrixXd& covariates); void UpdateSampleTrackersResidualInternalBasis(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model); void UpdateSampleTrackersResidualInternalNoBasis(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model); + + /*! + * \brief Track auxiliary data for cloglog ordinal bart models + * Vector of vectors to store these auxiliary data + * Each inner vector holds one type of data (order: Latent variable Z, Forest predictions, Cutpoints gamma, Cumsum exp of cutpoints) + */ + std::vector> ordinal_aux_data_vec_; + + /*! + * \brief Private helper methods for ordinal auxiliary data + * n_levels is the number of outcome levels for the ordinal response + * type_idx is the index of the type of auxiliary data (0: latent Z, 1: forest predictions, 2: cutpoints gamma, 3: cumsum exp of cutpoints) + */ + void ResizeOrdinalAuxData(data_size_t num_observations, int n_levels); }; /*! 
\brief Class storing sample-prediction map for each tree in an ensemble */ diff --git a/src/partition_tracker.cpp b/src/partition_tracker.cpp index 9d643380..9c2831fc 100644 --- a/src/partition_tracker.cpp +++ b/src/partition_tracker.cpp @@ -696,4 +696,46 @@ std::vector FeaturePresortPartition::NodeIndices(int node_id) { return out; } + +// ============================================================================ +// ORDINAL AUXILIARY DATA METHODS +// ============================================================================ + +double ForestTracker::GetOrdinalAuxData(int type_idx, data_size_t obs_idx) const { + // CHECK(IsValidOrdinalIndex(type_idx, obs_idx)); + return ordinal_aux_data_vec_[type_idx][obs_idx]; +} + +void ForestTracker::InitializeOrdinalAuxData(data_size_t num_observations, int n_levels) { + ResizeOrdinalAuxData(num_observations, n_levels); +} + +void ForestTracker::SetOrdinalAuxData(int type_idx, data_size_t obs_idx, double value) { + // CHECK(IsValidOrdinalIndex(type_idx, obs_idx)); + ordinal_aux_data_vec_[type_idx][obs_idx] = value; +} + +std::vector& ForestTracker::GetOrdinalAuxDataVector(int type_idx) { + // CHECK(IsValidOrdinalType(type_idx)); + return ordinal_aux_data_vec_[type_idx]; +} + +void ForestTracker::ResizeOrdinalAuxData(data_size_t num_observations, int n_levels) { + // 4 types of ordinal auxiliary data: latent Z, forest predictions, cutpoints gamma, cumsum exp of gammas + const int n_types = 4; + ordinal_aux_data_vec_.resize(n_types); + for (int i = 0; i < n_types; ++i) { + if (i < 2) { + // First two types (latent Z, forest predictions) are sized to num_observations + ordinal_aux_data_vec_[i].assign(num_observations, 0.0); + } else if (i == 2) { + // Cutpoints gamma: size n_levels - 1 + ordinal_aux_data_vec_[i].assign(n_levels - 1, 0.0); + } else if (i == 3) { + // Cumsum exp of gammas: size n_levels + ordinal_aux_data_vec_[i].assign(n_levels, 0.0); + } + } +} + } // namespace StochTree From e99791bf61546328946ddb7a3441c378b1dccd50 Mon Sep 17 00:00:00 2001 From: Entejar Alam Date: Wed, 17 Sep 2025 18:20:10 -0500 Subject: [PATCH 2/9] Added leaf model for CLogLog Ordinal BART --- include/stochtree/leaf_model.h | 252 ++++++++++++++++++++++++++++++++- src/leaf_model.cpp | 64 +++++++++ src/partition_tracker.cpp | 3 - 3 files changed, 310 insertions(+), 9 deletions(-) diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index 5359775d..a563cfe1 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -347,12 +347,14 @@ namespace StochTree { * 2. `kUnivariateRegressionLeafGaussian`: Every leaf node has a zero-centered univariate normal prior and every leaf is a linear model, multiplying the leaf parameter by a (fixed) basis. * 3. `kMultivariateRegressionLeafGaussian`: Every leaf node has a multivariate normal prior, centered around the zero vector, and every leaf is a linear model, matrix-multiplying the leaf parameters by a (fixed) basis vector. * 4. `kLogLinearVariance`: Every leaf node has a inverse gamma prior and every leaf is constant. + * 5. `kCloglogOrdinal`: Every leaf node has a log-gamma prior and every leaf is constant. */ enum ModelType { kConstantLeafGaussian, kUnivariateRegressionLeafGaussian, kMultivariateRegressionLeafGaussian, - kLogLinearVariance + kLogLinearVariance, + kCloglogOrdinal }; /*! 
\brief Sufficient statistic and associated operations for gaussian homoskedastic constant leaf outcome model */
@@ -969,6 +971,236 @@ class LogLinearVarianceLeafModel {
   GammaSampler gamma_sampler_;
 };
 
+/*! \brief Sufficient statistic and associated operations for the complementary log-log ordinal BART model */
+class CloglogOrdinalSuffStat {
+ public:
+  data_size_t n;
+  double sum_Y_less_K;
+  double other_sum;
+
+  /*!
+   * \brief Construct a new CloglogOrdinalSuffStat object, setting all sufficient statistics to zero
+   */
+  CloglogOrdinalSuffStat() {
+    n = 0;
+    sum_Y_less_K = 0.0;
+    other_sum = 0.0;
+  }
+
+  /*!
+   * \brief Accumulate data from observation `row_idx` into the sufficient statistics
+   *
+   * \param dataset Data object containing training data, including covariates
+   * \param outcome Data object containing the original ordinal outcome values, which are used to compute sufficient statistics
+   * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state
+   * \param row_idx Index of the training data observation from which the sufficient statistics should be updated
+   * \param tree_idx Index of the tree being updated in the course of this sufficient statistic update
+   */
+  void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
+    n += 1;
+
+    // Get ordinal outcome value for this observation
+    int y = static_cast<int>(outcome(row_idx));
+
+    // Get auxiliary data from tracker (slot layout: 0 = latent Z, 1 = forest predictions, 2 = cutpoints gamma, 3 = cumsum exp of gamma)
+    double Z = tracker.GetOrdinalAuxData(0, row_idx);             // latent variable Z
+    double lambda_minus = tracker.GetOrdinalAuxData(1, row_idx);  // forest predictions excluding current tree
+
+    // Get cutpoints gamma and cumulative sum of exp(gamma)
+    const std::vector<double>& gamma = tracker.GetOrdinalAuxDataVector(2);  // cutpoints gamma
+    const std::vector<double>& seg = tracker.GetOrdinalAuxDataVector(3);    // cumsum exp of gamma
+
+    int K = static_cast<int>(gamma.size()) + 1;  // Number of ordinal categories
+
+    if (y == K - 1) {
+      // Top category contributes to the rate term only through the cumulative sum of exp(gamma)
+      other_sum += std::exp(lambda_minus) * seg[y];
+    } else {
+      // Lower categories additionally contribute the truncated-exponential latent term
+      sum_Y_less_K += 1.0;
+      other_sum += std::exp(lambda_minus) * (Z * std::exp(gamma[y]) + seg[y]);
+    }
+  }
+
+  /*!
+   * \brief Reset all of the sufficient statistics to zero
+   */
+  void ResetSuffStat() {
+    n = 0;
+    sum_Y_less_K = 0.0;
+    other_sum = 0.0;
+  }
+
+  /*!
+   * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs`
+   *
+   * \param lhs First sufficient statistic ("left hand side")
+   * \param rhs Second sufficient statistic ("right hand side")
+   */
+  void AddSuffStat(CloglogOrdinalSuffStat& lhs, CloglogOrdinalSuffStat& rhs) {
+    n = lhs.n + rhs.n;
+    sum_Y_less_K = lhs.sum_Y_less_K + rhs.sum_Y_less_K;
+    other_sum = lhs.other_sum + rhs.other_sum;
+  }
+
+  /*!
+   * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs`
+   *
+   * \param lhs First sufficient statistic ("left hand side")
+   * \param rhs Second sufficient statistic ("right hand side")
+   */
+  void SubtractSuffStat(CloglogOrdinalSuffStat& lhs, CloglogOrdinalSuffStat& rhs) {
+    n = lhs.n - rhs.n;
+    sum_Y_less_K = lhs.sum_Y_less_K - rhs.sum_Y_less_K;
+    other_sum = lhs.other_sum - rhs.other_sum;
+  }
+
+  /*!
+ * \brief Check whether accumulated sample size, `n`, is greater than some threshold + * + * \param threshold Value used to compute `n > threshold` + */ + bool SampleGreaterThan(data_size_t threshold) { + return n > threshold; + } + + /*! + * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold + * + * \param threshold Value used to compute `n >= threshold` + */ + bool SampleGreaterThanEqual(data_size_t threshold) { + return n >= threshold; + } + + /*! + * \brief Return the sample size accumulated by a sufficient stat object + */ + data_size_t SampleSize() { + return n; + } +}; + +/*! \brief Marginal likelihood and posterior computation for complementary log-log ordinal BART model */ +class CloglogOrdinalLeafModel { + public: + /*! + * \brief Construct a new CloglogOrdinalLeafModel object + * + * \param a Shape parameter for log-gamma prior on leaf parameters + * \param b rate parameter for log-gamma prior on leaf parameters + * Log-gamma density: f(x) = b^a / Gamma(a) * exp(a*x - b*exp(x)) + * Relationship to tau (scale of leaf parameters): tau^2 = trigamma(a) + */ + CloglogOrdinalLeafModel(double a, double b) { + a_ = a; + b_ = b; + gamma_sampler_ = GammaSampler(); + tau_ = std::sqrt(boost::math::trigamma(a_)); + } + ~CloglogOrdinalLeafModel() {} + + /*! + * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. + */ + double SplitLogMarginalLikelihood(CloglogOrdinalSuffStat& left_stat, CloglogOrdinalSuffStat& right_stat, double global_variance); + + /*! + * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. + */ + double NoSplitLogMarginalLikelihood(CloglogOrdinalSuffStat& suff_stat, double global_variance); + + /*! + * \brief Helper function to compute log marginal likelihood from sufficient statistics + */ + double SuffStatLogMarginalLikelihood(CloglogOrdinalSuffStat& suff_stat, double global_variance); + + /*! + * \brief Posterior shape parameter for leaf node log-gamma distribution + */ + double PosteriorParameterShape(CloglogOrdinalSuffStat& suff_stat, double global_variance); + + /*! + * \brief Posterior rate parameter for leaf node log-gamma distribution + */ + double PosteriorParameterRate(CloglogOrdinalSuffStat& suff_stat, double global_variance); + + /*! + * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters. + * Samples from log-gamma: sample from gamma, then take log. + */ + void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); + + void SetScale(double tau) {tau_ = tau;} + + /*! + * \brief Get the current scale parameter value (tau_) + * \return Current tau_ value + */ + double GetScale() const {return tau_;} + + inline bool RequiresBasis() {return false;} + + /*! + * \brief Convert tau_ (scale_lambda i.e. 
scale for leaf parameters) to alpha (shape) and beta (rate) parameters for the log-gamma prior + * + * \param alpha Output: shape parameter for log-gamma prior + * \param beta Output: rate parameter for log-gamma prior + * \param tau Scale parameter (tau_) for leaf parameters + */ + void ScaleTauToAlphaBeta(double& alpha, double& beta, const double tau) { + double tau_sq = tau * tau; + alpha = TrigammaInverse(tau_sq); + // Note: Using exponential of digamma function for beta calculation + beta = std::exp(boost::math::digamma(alpha)); + } + + /*! + * \brief Convert alpha (shape) and beta (rate) parameters (for the log-gamma prior) back to tau_ (scale_lambda i.e. scale for leaf parameters) + * + * \param alpha Shape parameter for log-gamma prior + * \param beta Rate parameter for log-gamma prior + * \return tau Scale parameter (tau_) for leaf parameters + */ + double AlphaBetaToScaleTau(double alpha, double beta) { + // Inverse of the transformation: tau_sq = trigamma(alpha) + double tau_sq = boost::math::trigamma(alpha); + return std::sqrt(tau_sq); + } + + private: + /*! + * \brief Compute inverse trigamma function using Newton's method + * + * Implementation adapted from limma package in R, originally by Gordon Smyth + * + * \param x Input value for which to compute trigamma inverse + * \return Value y such that trigamma(y) = x + */ + double TrigammaInverse(double x) { + // Very large and very small values - deal with using asymptotics + if (x > 1E7) { + return 1.0 / std::sqrt(x); + } + if (x < 1E-6) { + return 1.0 / x; + } + + // Otherwise, use Newton's method + double y = 0.5 + 1.0 / x; + for (int i = 0; i < 50; i++) { + double tri = boost::math::trigamma(y); + double dif = tri * (1.0 - tri / x) / boost::math::polygamma(3, y); // tetragamma is polygamma(3, x) + y += dif; + if (-dif / y < 1E-8) break; + } + + return y; + } + double a_; + double b_; + GammaSampler gamma_sampler_; + double tau_; +}; + /*! * \brief Unifying layer for disparate sufficient statistic class types * @@ -980,7 +1212,8 @@ class LogLinearVarianceLeafModel { using SuffStatVariant = std::variant; + LogLinearVarianceSuffStat, + CloglogOrdinalSuffStat>; /*! * \brief Unifying layer for disparate leaf model class types @@ -993,7 +1226,8 @@ using SuffStatVariant = std::variant; + LogLinearVarianceLeafModel, + CloglogOrdinalLeafModel>; template static inline SuffStatVariant createSuffStat(SuffStatConstructorArgs... 
leaf_suff_stat_args) {
@@ -1018,8 +1252,10 @@ static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_di
     return createSuffStat<GaussianUnivariateRegressionSuffStat>();
   } else if (model_type == kMultivariateRegressionLeafGaussian) {
     return createSuffStat<GaussianMultivariateRegressionSuffStat>(basis_dim);
-  } else {
+  } else if (model_type == kLogLinearVariance) {
     return createSuffStat<LogLinearVarianceSuffStat>();
+  } else {
+    return createSuffStat<CloglogOrdinalSuffStat>();
   }
 }
 
@@ -1031,16 +1267,20 @@ static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_di
  * \param Sigma0 Value of the leaf node prior covariance matrix, only used if `model_type = kMultivariateRegressionLeafGaussian`
  * \param a Value of the leaf node inverse gamma prior shape parameter, only used if `model_type = kLogLinearVariance`
  * \param b Value of the leaf node inverse gamma prior scale parameter, only used if `model_type = kLogLinearVariance`
+ * \param c Value of the leaf node log-gamma prior shape parameter, only used if `model_type = kCloglogOrdinal`
+ * \param d Value of the leaf node log-gamma prior rate parameter, only used if `model_type = kCloglogOrdinal`
  */
-static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b) {
+static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b, double c, double d) {
   if (model_type == kConstantLeafGaussian) {
     return createLeafModel<GaussianConstantLeafModel>(tau);
   } else if (model_type == kUnivariateRegressionLeafGaussian) {
     return createLeafModel<GaussianUnivariateRegressionLeafModel>(tau);
   } else if (model_type == kMultivariateRegressionLeafGaussian) {
     return createLeafModel<GaussianMultivariateRegressionLeafModel>(Sigma0);
-  } else {
+  } else if (model_type == kLogLinearVariance) {
     return createLeafModel<LogLinearVarianceLeafModel>(a, b);
+  } else {
+    return createLeafModel<CloglogOrdinalLeafModel>(c, d);
   }
 }
 
diff --git a/src/leaf_model.cpp b/src/leaf_model.cpp
index 3b59ab96..78d8da76 100644
--- a/src/leaf_model.cpp
+++ b/src/leaf_model.cpp
@@ -274,4 +274,68 @@ void LogLinearVarianceLeafModel::SetEnsembleRootPredictedValue(ForestDataset& da
   }
 }
 
+// ============================================================================
+// Cloglog Ordinal Leaf Model
+// ============================================================================
+
+double CloglogOrdinalLeafModel::SplitLogMarginalLikelihood(CloglogOrdinalSuffStat& left_stat, CloglogOrdinalSuffStat& right_stat, double global_variance) {
+  double left_log_ml = SuffStatLogMarginalLikelihood(left_stat, global_variance);
+  double right_log_ml = SuffStatLogMarginalLikelihood(right_stat, global_variance);
+  return left_log_ml + right_log_ml;
+}
+
+double CloglogOrdinalLeafModel::NoSplitLogMarginalLikelihood(CloglogOrdinalSuffStat& suff_stat, double global_variance) {
+  return SuffStatLogMarginalLikelihood(suff_stat, global_variance);
+}
+
+double CloglogOrdinalLeafModel::SuffStatLogMarginalLikelihood(CloglogOrdinalSuffStat& suff_stat, double global_variance) {
+  double prior_terms = a_ * std::log(b_) - boost::math::lgamma(a_);
+  double a_term = a_ + suff_stat.sum_Y_less_K;
+  double b_term = b_ + suff_stat.other_sum;
+  double log_b_term = std::log(b_term);
+  double lgamma_a_term = boost::math::lgamma(a_term);
+  double resid_term = a_term * log_b_term;
+  double log_ml = prior_terms + lgamma_a_term - resid_term;
+  return log_ml;
+}
+
+double CloglogOrdinalLeafModel::PosteriorParameterShape(CloglogOrdinalSuffStat& suff_stat, double global_variance) {
+  return a_ + suff_stat.sum_Y_less_K;
+}
+
+double CloglogOrdinalLeafModel::PosteriorParameterRate(CloglogOrdinalSuffStat& suff_stat, double global_variance) {
+  return b_ + suff_stat.other_sum;
+}
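+
+// For orientation, a sketch of the algebra behind SuffStatLogMarginalLikelihood above
+// (an annotation, not additional logic): writing s1 = sum_Y_less_K and s2 = other_sum,
+//
+//   log p(y in node) = a*log(b) - lgamma(a) + lgamma(a + s1) - (a + s1)*log(b + s2),
+//
+// i.e. the log of the Gamma integral obtained by marginalizing exp(leaf value) ~ Gamma(a, rate b)
+// out of the node's likelihood contribution. The matching Gibbs step in SampleLeafParameters
+// below draws g ~ Gamma(a + s1, rate b + s2) and sets the leaf value to log(g), consistent
+// with PosteriorParameterShape and PosteriorParameterRate.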
+
+void CloglogOrdinalLeafModel::SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen) {
+  // Vector of leaf indices for tree
+  std::vector<int32_t> tree_leaves = tree->GetLeaves();
+
+  // Initialize sufficient statistics
+  CloglogOrdinalSuffStat node_suff_stat = CloglogOrdinalSuffStat();
+
+  // Sample each leaf node parameter
+  double node_shape;
+  double node_rate;
+  double node_mu;
+  int32_t leaf_id;
+  for (int i = 0; i < static_cast<int>(tree_leaves.size()); i++) {
+    // Compute leaf node sufficient statistics
+    leaf_id = tree_leaves[i];
+    node_suff_stat.ResetSuffStat();
+    AccumulateSingleNodeSuffStat<CloglogOrdinalSuffStat, false>(node_suff_stat, dataset, tracker, residual, tree_num, leaf_id);
+
+    // Compute posterior shape and rate
+    node_shape = PosteriorParameterShape(node_suff_stat, global_variance);
+    node_rate = PosteriorParameterRate(node_suff_stat, global_variance);
+
+    // Draw from the log-gamma full conditional: sample g ~ Gamma(node_shape, rate = node_rate)
+    // and set the leaf parameter to log(g)
+    double gamma_sample = gamma_sampler_.Sample(node_shape, node_rate, gen);
+    node_mu = std::log(gamma_sample);
+    tree->SetLeaf(leaf_id, node_mu);
+  }
+}
+
 } // namespace StochTree
diff --git a/src/partition_tracker.cpp b/src/partition_tracker.cpp
index 9c2831fc..8359faed 100644
--- a/src/partition_tracker.cpp
+++ b/src/partition_tracker.cpp
@@ -702,7 +702,6 @@ std::vector<data_size_t> FeaturePresortPartition::NodeIndices(int node_id) {
 // ============================================================================
 
 double ForestTracker::GetOrdinalAuxData(int type_idx, data_size_t obs_idx) const {
-  // CHECK(IsValidOrdinalIndex(type_idx, obs_idx));
   return ordinal_aux_data_vec_[type_idx][obs_idx];
 }
 
@@ -711,12 +710,10 @@ void ForestTracker::InitializeOrdinalAuxData(data_size_t num_observations, int n
 }
 
 void ForestTracker::SetOrdinalAuxData(int type_idx, data_size_t obs_idx, double value) {
-  // CHECK(IsValidOrdinalIndex(type_idx, obs_idx));
   ordinal_aux_data_vec_[type_idx][obs_idx] = value;
 }
 
 std::vector<double>& ForestTracker::GetOrdinalAuxDataVector(int type_idx) {
-  // CHECK(IsValidOrdinalType(type_idx));
   return ordinal_aux_data_vec_[type_idx];
 }
 
From 8f77e153370a634925278db1cb276ca65d750d46 Mon Sep 17 00:00:00 2001
From: Entejar Alam
Date: Wed, 17 Sep 2025 18:34:44 -0500
Subject: [PATCH 3/9] Added ordinal_sampler

---
 include/stochtree/ordinal_sampler.h | 86 ++++++++++++++++++++++++++
 src/ordinal_sampler.cpp             | 94 +++++++++++++++++++++++++++++
 2 files changed, 180 insertions(+)
 create mode 100644 include/stochtree/ordinal_sampler.h
 create mode 100644 src/ordinal_sampler.cpp

diff --git a/include/stochtree/ordinal_sampler.h b/include/stochtree/ordinal_sampler.h
new file mode 100644
index 00000000..ec148b5b
--- /dev/null
+++ b/include/stochtree/ordinal_sampler.h
@@ -0,0 +1,86 @@
+/*!
+ * Copyright (c) 2024 stochtree authors. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ */
+#ifndef STOCHTREE_ORDINAL_SAMPLER_H_
+#define STOCHTREE_ORDINAL_SAMPLER_H_
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace StochTree {
+
+/*!
+ * \brief Sampler for ordinal model hyperparameters + * + * This class handles MCMC sampling for ordinal-specific parameters: + * - Truncated exponential latent variables (Z) + * - Cutpoint parameters (gamma) + * - Cumulative sum of exp(gamma) (seg) [derived parameter] + */ +class OrdinalSampler { + public: + OrdinalSampler() { + gamma_sampler_ = GammaSampler(); + } + ~OrdinalSampler() {} + + /*! + * \brief Sample from truncated exponential distribution + * + * Samples from exponential distribution truncated to [0,1] + * + * \param lambda Rate parameter for exponential distribution + * \param gen Random number generator + * \return Sampled value from truncated exponential + */ + static double SampleTruncatedExponential(double lambda, std::mt19937& gen); + + + /*! + * \brief Update truncated exponential latent variables (Z) + * + * \param dataset Forest dataset containing training data (covariates) + * \param outcome Vector of outcome values + * \param tracker Forest tracker containing auxiliary data + * \param gen Random number generator + */ + void UpdateLatentVariables(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, + std::mt19937& gen); + + /*! + * \brief Update gamma cutpoint parameters + * + * \param dataset Forest dataset containing training data (covariates) + * \param outcome Vector of outcome values + * \param tracker Forest tracker containing auxiliary data + * \param alpha_gamma Shape parameter for log-gamma prior on cutpoints gamma + * \param beta_gamma Rate parameter for log-gamma prior on cutpoints gamma + * \param gamma_0 Fixed value for first cutpoint parameter (for identifiability) + * \param gen Random number generator + */ + void UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, + double alpha_gamma, double beta_gamma, double gamma_0, + std::mt19937& gen); + + /*! 
+ * \brief Update cumulative exponential sums (seg)
+ *
+ * \param tracker Forest tracker containing auxiliary data
+ */
+  void UpdateCumulativeExpSums(ForestTracker& tracker);
+
+ private:
+  GammaSampler gamma_sampler_;
+};
+
+} // namespace StochTree
+
+#endif  // STOCHTREE_ORDINAL_SAMPLER_H_
diff --git a/src/ordinal_sampler.cpp b/src/ordinal_sampler.cpp
new file mode 100644
index 00000000..27b11f63
--- /dev/null
+++ b/src/ordinal_sampler.cpp
@@ -0,0 +1,94 @@
+#include
+#include
+
+namespace StochTree {
+
+double OrdinalSampler::SampleTruncatedExponential(double lambda, std::mt19937& gen) {
+  // Inverse-CDF draw from an exponential(rate = lambda) distribution truncated to [0, 1]
+  std::uniform_real_distribution<double> unif(0.0, 1.0);
+  double u = unif(gen);
+  double a = 1.0 - u * (1.0 - std::exp(-lambda));
+  return -std::log(a) / lambda;
+}
+
+void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker,
+                                           std::mt19937& gen) {
+  // Get auxiliary data vectors
+  const std::vector<double>& gamma = tracker.GetOrdinalAuxDataVector(2);       // gamma cutpoints
+  const std::vector<double>& lambda_hat = tracker.GetOrdinalAuxDataVector(1);  // forest predictions: lambda_hat_i = sum_t lambda_t(x_i)
+  std::vector<double>& Z = tracker.GetOrdinalAuxDataVector(0);                 // latent variables: z_i ~ TExp(e^{gamma[y_i] + lambda_hat_i}; 0, 1)
+
+  int K = static_cast<int>(gamma.size()) + 1;  // Number of ordinal categories
+  int N = dataset.NumObservations();
+
+  // Update truncated exponentials (stored in latent auxiliary data slot 0):
+  // z_i ~ TExp(rate = e^{gamma[y_i] + lambda_hat_i}; 0, 1),
+  // where y_i is the ordinal outcome for observation i (already converted to {0, 1, ..., K-1})
+  // and lambda_hat_i is the total forest prediction for observation i.
+  // If y_i = K-1 (last category), we set z_i = 1.0 deterministically, purely for bookkeeping:
+  // z_i enters the likelihood only when y_i < K-1, so it need not be sampled otherwise.
+  for (int i = 0; i < N; i++) {
+    int y = static_cast<int>(outcome(i));
+    if (y == K - 1) {
+      Z[i] = 1.0;
+    } else {
+      double rate = std::exp(gamma[y] + lambda_hat[i]);
+      Z[i] = SampleTruncatedExponential(rate, gen);
+    }
+  }
+}
+
+void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker,
+                                       double alpha_gamma, double beta_gamma, double gamma_0,
+                                       std::mt19937& gen) {
+  // Get auxiliary data vectors
+  std::vector<double>& gamma = tracker.GetOrdinalAuxDataVector(2);             // cutpoints gamma_k's
+  const std::vector<double>& Z = tracker.GetOrdinalAuxDataVector(0);           // latent variables z_i's
+  const std::vector<double>& lambda_hat = tracker.GetOrdinalAuxDataVector(1);  // forest predictions: lambda_hat_i = sum_t lambda_t(x_i)
+
+  int K = static_cast<int>(gamma.size()) + 1;  // Number of ordinal categories
+  int N = dataset.NumObservations();
+
+  // Compute sufficient statistics A[k] and B[k] for the gamma[k] update
+  std::vector<double> A(K - 1, 0.0);
+  std::vector<double> B(K - 1, 0.0);
+
+  for (int i = 0; i < N; i++) {
+    int y = static_cast<int>(outcome(i));
+    if (y < K - 1) {
+      A[y] += 1.0;
+      B[y] += Z[i] * std::exp(lambda_hat[i]);
+    }
+    for (int k = 0; k < y; k++) {
+      B[k] += std::exp(lambda_hat[i]);
+    }
+  }
+
+  // Update the cutpoints via conjugate draws on the gamma scale.
+  // gamma[0] is fixed at gamma_0 (e.g., 0) for identifiability, so only k >= 1 is sampled.
+  gamma[0] = gamma_0;
+  for (int k = 1; k < static_cast<int>(gamma.size()); k++) {
+    double shape = A[k] + alpha_gamma;
+    double rate = B[k] + beta_gamma;
+    double gamma_sample = gamma_sampler_.Sample(shape, rate, gen);
+    gamma[k] = std::log(gamma_sample);
+  }
+}
+
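+// For orientation, a sketch of the conjugacy used in UpdateGammaParams above: with
+//   A_k = #{i : y_i = k}   and
+//   B_k = sum_{i : y_i = k} z_i * exp(lambda_hat_i) + sum_{i : y_i > k} exp(lambda_hat_i),
+// the full conditional of each free cutpoint is
+//   exp(gamma_k) | -  ~  Gamma(A_k + alpha_gamma, rate = B_k + beta_gamma),
+// so each draw is taken on the gamma scale and stored as gamma_k = log(draw),
+// with gamma_0 held fixed for identifiability.
+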
+void OrdinalSampler::UpdateCumulativeExpSums(ForestTracker& tracker) {
+  // Get auxiliary data vectors
+  const std::vector<double>& gamma = tracker.GetOrdinalAuxDataVector(2);  // cutpoints gamma_k's
+  std::vector<double>& seg = tracker.GetOrdinalAuxDataVector(3);          // seg_k = sum_{j=0}^{k-1} exp(gamma_j)
+
+  // Update seg (cumulative sums of exp(gamma)), with the convention seg_0 = 0
+  for (int j = 0; j < static_cast<int>(seg.size()); j++) {
+    if (j == 0) {
+      seg[j] = 0.0;
+    } else {
+      seg[j] = seg[j - 1] + std::exp(gamma[j - 1]);
+    }
+  }
+}
+
+} // namespace StochTree
From 8547425cf1abfb95c3c999316f0ffd5b147b7720 Mon Sep 17 00:00:00 2001
From: Entejar Alam
Date: Wed, 17 Sep 2025 18:50:19 -0500
Subject: [PATCH 4/9] Updated tree_sampler.h

Added functionality to adjust the model states before/after tree sampling for
CLogLog Ordinal BART
---
 include/stochtree/tree_sampler.h | 42 ++++++++++++++++++++++++++++++--
 1 file changed, 40 insertions(+), 2 deletions(-)

diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h
index 68c9c15a..675ef6c0 100644
--- a/include/stochtree/tree_sampler.h
+++ b/include/stochtree/tree_sampler.h
@@ -394,6 +394,40 @@ static inline void UpdateVarModelTree(ForestTracker& tracker, ForestDataset& dat
   }
 }
 
+static inline void UpdateCLogLogModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num,
+                                          bool requires_basis, bool tree_new) {
+  data_size_t n = dataset.GetCovariates().rows();
+
+  double pred_value;
+  int32_t leaf_pred;
+  double pred_delta;
+  for (data_size_t i = 0; i < n; i++) {
+    if (tree_new) {
+      // If the tree has been newly sampled or adjusted, we must rerun the prediction
+      // method and update the SamplePredMapper stored in tracker
+      leaf_pred = tracker.GetNodeId(i, tree_num);
+      if (requires_basis) {
+        pred_value = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
+      } else {
+        pred_value = tree->PredictFromNode(leaf_pred);
+      }
+      pred_delta = pred_value - tracker.GetTreeSamplePrediction(i, tree_num);
+      tracker.SetTreeSamplePrediction(i, tree_num, pred_value);
+      tracker.SetSamplePrediction(i, tracker.GetSamplePrediction(i) + pred_delta);
+      // Set auxiliary data slot 1 to the forest prediction excluding the current tree (tree_num)
+      tracker.SetOrdinalAuxData(1, i, tracker.GetSamplePrediction(i) - pred_value);
+    } else {
+      // If the tree has not yet been modified via a sampling step,
+      // we can query its prediction directly from the SamplePredMapper stored in tracker
+      pred_value = tracker.GetTreeSamplePrediction(i, tree_num);
+      // Refresh auxiliary data slot 1 so that it excludes the current tree (tree_num).
+      // The tree itself is unchanged, but slot 1 still reflects whichever tree was
+      // excluded by the previous sampling step, so it must be recomputed here.
+ double current_lambda_hat = tracker.GetSamplePrediction(i); + double lambda_minus = current_lambda_hat - pred_value; + tracker.SetOrdinalAuxData(1, i, lambda_minus); + } + } +} + template static inline std::tuple EvaluateProposedSplit( ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model, @@ -448,7 +482,9 @@ static inline std::tuple EvaluateExist template static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { - if (backfitting) { + if constexpr (std::is_same_v) { + UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), false); + } else if (backfitting) { UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus(), false); } else { // TODO: think about a generic way to store "state" corresponding to the other models? @@ -459,7 +495,9 @@ static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafMod template static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { - if (backfitting) { + if constexpr (std::is_same_v) { + UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), true); + } else if (backfitting) { UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus(), true); } else { // TODO: think about a generic way to store "state" corresponding to the other models? From 6c1d3ce549af37a3cbe8e7e86372b37a79dfde1a Mon Sep 17 00:00:00 2001 From: Entejar Alam Date: Wed, 17 Sep 2025 18:57:10 -0500 Subject: [PATCH 5/9] Updated sampler.cpp --- src/sampler.cpp | 101 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/src/sampler.cpp b/src/sampler.cpp index 212ccb42..ee8bd6e6 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include [[cpp11::register]] @@ -326,3 +327,103 @@ cpp11::writable::integers sample_without_replacement_integer_cpp( // Return result return(output); } + +// ============================================================================ +// ORDINAL AUXILIARY DATA FUNCTIONS +// ============================================================================ + +[[cpp11::register]] +void ordinal_aux_data_initialize_cpp(cpp11::external_pointer tracker_ptr, StochTree::data_size_t num_observations, int n_levels) { + tracker_ptr->InitializeOrdinalAuxData(num_observations, n_levels); +} + +[[cpp11::register]] +double ordinal_aux_data_get_cpp(cpp11::external_pointer tracker_ptr, int type_idx, StochTree::data_size_t obs_idx) { + return tracker_ptr->GetOrdinalAuxData(type_idx, obs_idx); +} + +[[cpp11::register]] +void ordinal_aux_data_set_cpp(cpp11::external_pointer tracker_ptr, int type_idx, StochTree::data_size_t obs_idx, double value) { + tracker_ptr->SetOrdinalAuxData(type_idx, obs_idx, value); +} + +[[cpp11::register]] +cpp11::writable::doubles ordinal_aux_data_get_vector_cpp(cpp11::external_pointer tracker_ptr, int type_idx) { + const std::vector& aux_vec = tracker_ptr->GetOrdinalAuxDataVector(type_idx); + cpp11::writable::doubles output(aux_vec.size()); + for (size_t i = 0; i < aux_vec.size(); i++) { + output[i] = aux_vec[i]; + } + return output; +} + 
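+
+// Illustrative sketch (not part of the patch; the helper name is hypothetical): the slot
+// layout established by InitializeOrdinalAuxData(n, K) / ResizeOrdinalAuxData can be
+// summarized as follows.
+//
+// void check_ordinal_slot_layout(StochTree::ForestTracker& tracker) {
+//   size_t n = tracker.GetOrdinalAuxDataVector(0).size();          // latent Z: one per observation
+//   assert(tracker.GetOrdinalAuxDataVector(1).size() == n);        // forest predictions: one per observation
+//   size_t num_cutpoints = tracker.GetOrdinalAuxDataVector(2).size();        // cutpoints gamma: K - 1 values
+//   assert(tracker.GetOrdinalAuxDataVector(3).size() == num_cutpoints + 1);  // cumsum exp(gamma): K values
+// }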
+[[cpp11::register]] +void ordinal_aux_data_set_vector_cpp(cpp11::external_pointer tracker_ptr, int type_idx, cpp11::doubles values) { + std::vector& aux_vec = tracker_ptr->GetOrdinalAuxDataVector(type_idx); + if (aux_vec.size() != values.size()) { + cpp11::stop("Size mismatch between auxiliary data vector and input values"); + } + for (size_t i = 0; i < values.size(); i++) { + aux_vec[i] = values[i]; + } +} + +[[cpp11::register]] +void ordinal_aux_data_update_cumsum_exp_cpp(cpp11::external_pointer tracker_ptr) { + // Get auxiliary data vectors + const std::vector& gamma = tracker_ptr->GetOrdinalAuxDataVector(2); // cutpoints gamma + std::vector& seg = tracker_ptr->GetOrdinalAuxDataVector(3); // cumsum exp gamma + + // Update seg (cumulative sum of exp(gamma)) + for (size_t j = 0; j < seg.size(); j++) { + if (j == 0) { + seg[j] = 0.0; + } else { + seg[j] = seg[j - 1] + std::exp(gamma[j - 1]); + } + } +} + +// ============================================================================ +// ORDINAL SAMPLER FUNCTIONS +// ============================================================================ + +[[cpp11::register]] +cpp11::external_pointer ordinal_sampler_cpp() { + std::unique_ptr sampler_ptr = std::make_unique(); + return cpp11::external_pointer(sampler_ptr.release()); +} + +[[cpp11::register]] +void ordinal_sampler_update_latent_variables_cpp( + cpp11::external_pointer sampler_ptr, + cpp11::external_pointer data_ptr, + cpp11::external_pointer outcome_ptr, + cpp11::external_pointer tracker_ptr, + cpp11::external_pointer rng_ptr +) { + sampler_ptr->UpdateLatentVariables(*data_ptr, outcome_ptr->GetData(), *tracker_ptr, *rng_ptr); +} + +[[cpp11::register]] +void ordinal_sampler_update_gamma_params_cpp( + cpp11::external_pointer sampler_ptr, + cpp11::external_pointer data_ptr, + cpp11::external_pointer outcome_ptr, + cpp11::external_pointer tracker_ptr, + double alpha_gamma, + double beta_gamma, + double gamma_0, + cpp11::external_pointer rng_ptr +) { + sampler_ptr->UpdateGammaParams(*data_ptr, outcome_ptr->GetData(), *tracker_ptr, alpha_gamma, beta_gamma, gamma_0, *rng_ptr); +} + +[[cpp11::register]] +void ordinal_sampler_update_cumsum_exp_cpp( + cpp11::external_pointer sampler_ptr, + cpp11::external_pointer tracker_ptr +) { + sampler_ptr->UpdateCumulativeExpSums(*tracker_ptr); +} + From 084be881caa7202decfd505ed29a64f348e8ba21 Mon Sep 17 00:00:00 2001 From: Entejar Alam Date: Sun, 28 Sep 2025 06:21:51 -0500 Subject: [PATCH 6/9] Added cloglog_ordinal_bart.R function --- R/cloglog_ordinal_bart.R | 180 ++++++ R/cpp11.R | 48 +- include/stochtree/category_tracker.h | 4 + include/stochtree/common.h | 4 + include/stochtree/container.h | 5 + include/stochtree/cutpoint_candidates.h | 2 + include/stochtree/data.h | 1 + include/stochtree/ensemble.h | 6 + include/stochtree/io.h | 2 + include/stochtree/leaf_model.h | 546 ++++++++++--------- include/stochtree/log.h | 2 + include/stochtree/meta.h | 1 + include/stochtree/ordinal_sampler.h | 1 - include/stochtree/partition_tracker.h | 29 +- include/stochtree/random.h | 1 + include/stochtree/random_effects.h | 3 + include/stochtree/slice_sampler.h | 180 ++++++ include/stochtree/tree.h | 3 + include/stochtree/tree_sampler.h | 365 +++++++------ include/stochtree/variance_model.h | 4 + man/bart.Rd | 8 +- man/bcf.Rd | 24 +- man/cloglog_ordinal_bart.Rd | 47 ++ man/createBARTModelFromCombinedJson.Rd | 8 +- man/createBARTModelFromCombinedJsonString.Rd | 8 +- man/createBARTModelFromJson.Rd | 8 +- man/createBARTModelFromJsonFile.Rd | 8 +- 
man/createBARTModelFromJsonString.Rd | 8 +- man/createBCFModelFromCombinedJson.Rd | 30 +- man/createBCFModelFromCombinedJsonString.Rd | 30 +- man/createBCFModelFromJson.Rd | 34 +- man/createBCFModelFromJsonFile.Rd | 34 +- man/createBCFModelFromJsonString.Rd | 30 +- man/createForestModel.Rd | 8 +- man/getRandomEffectSamples.bartmodel.Rd | 16 +- man/getRandomEffectSamples.bcfmodel.Rd | 34 +- man/predict.bartmodel.Rd | 8 +- man/predict.bcfmodel.Rd | 22 +- man/preprocessPredictionData.Rd | 2 +- man/resetForestModel.Rd | 22 +- man/resetRandomEffectsModel.Rd | 4 +- man/resetRandomEffectsTracker.Rd | 4 +- man/rootResetRandomEffectsModel.Rd | 4 +- man/rootResetRandomEffectsTracker.Rd | 4 +- man/saveBARTModelToJson.Rd | 8 +- man/saveBARTModelToJsonFile.Rd | 8 +- man/saveBARTModelToJsonString.Rd | 8 +- man/saveBCFModelToJson.Rd | 34 +- man/saveBCFModelToJsonFile.Rd | 34 +- man/saveBCFModelToJsonString.Rd | 34 +- src/Makevars.in | 1 + src/Makevars.win.in | 1 + src/R_data.cpp | 1 + src/R_random_effects.cpp | 2 + src/cpp11.cpp | 103 +++- src/cutpoint_candidates.cpp | 1 + src/data.cpp | 1 + src/forest.cpp | 2 + src/io.cpp | 2 + src/kernel.cpp | 2 + src/leaf_model.cpp | 2 + src/ordinal_sampler.cpp | 24 +- src/partition_tracker.cpp | 80 ++- src/py_stochtree.cpp | 19 +- src/sampler.cpp | 143 ++--- src/serialization.cpp | 3 + src/stochtree_types.h | 2 + src/tree.cpp | 4 + 68 files changed, 1503 insertions(+), 808 deletions(-) create mode 100644 R/cloglog_ordinal_bart.R create mode 100644 include/stochtree/slice_sampler.h create mode 100644 man/cloglog_ordinal_bart.Rd diff --git a/R/cloglog_ordinal_bart.R b/R/cloglog_ordinal_bart.R new file mode 100644 index 00000000..9cc9b63a --- /dev/null +++ b/R/cloglog_ordinal_bart.R @@ -0,0 +1,180 @@ +#' Run the BART algorithm for ordinal outcomes using a complementary log-log link +#' +#' @param X A numeric matrix of predictors (training data). +#' @param y A numeric vector of ordinal outcomes (positive integers starting from 1). +#' @param X_test An optional numeric matrix of predictors (test data). +#' @param n_trees Number of trees in the BART ensemble. Default: `50`. +#' @param n_samples_mcmc Total number of MCMC samples to draw. Default: `500`. +#' @param n_burnin Number of burn-in samples to discard. Default: `250`. +#' @param n_thin Thinning interval for MCMC samples. Default: `1`. +#' @param alpha_gamma Shape parameter for the log-gamma prior on cutpoints. Default: `2.0`. +#' @param beta_gamma Rate parameter for the log-gamma prior on cutpoints. Default: `2.0`. +#' @param variable_weights Optional vector of variable weights for splitting (default: equal weights). +#' @param feature_types Optional vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous). 
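+#' @param seed Optional integer seed for reproducibility. Default: `NULL`.
+#' @return A list of class `"cloglog_ordinal_bart"` with elements `forest_predictions_train`, `forest_predictions_test` (`NULL` if `X_test` is not supplied), `gamma_samples`, `latent_samples`, `scale_leaf`, `ordinal_outcome`, `n_trees`, and `n_keep`.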
+
+
+cloglog_ordinal_bart <- function(X, y, X_test = NULL,
+                                 n_trees = 50,
+                                 n_samples_mcmc = 500,
+                                 n_burnin = 250,
+                                 n_thin = 1,
+                                 alpha_gamma = 2.0,
+                                 beta_gamma = 2.0,
+                                 variable_weights = NULL,
+                                 feature_types = NULL,
+                                 seed = NULL) {
+
+    # BART parameters
+    alpha_bart <- 0.95
+    beta_bart <- 2
+    min_samples_in_leaf <- 5
+    max_depth <- 10
+    scale_leaf <- 2 / sqrt(n_trees)
+    cutpoint_grid_size <- 100 # Required by stochtree:::sample_mcmc_one_iteration_cpp; only grow-from-root (GFR) uses it, so it has no effect on ordinal BART MCMC
+
+    # Fixed for identifiability (can be passed as an argument later if desired)
+    gamma_0 <- 0.0 # First gamma cutpoint fixed at gamma_0 = 0
+
+    # Determine whether a test dataset is provided
+    has_test <- !is.null(X_test)
+    # Data checks
+    if (!is.matrix(X)) X <- as.matrix(X)
+    if (!is.numeric(y)) y <- as.numeric(y)
+    if (has_test && !is.matrix(X_test)) X_test <- as.matrix(X_test)
+
+    n_samples <- nrow(X)
+    n_features <- ncol(X)
+
+    if (any(y < 1) || any(y != round(y))) {
+        stop("Ordinal outcome y must contain positive integers starting from 1")
+    }
+
+    # Convert from 1-based (R) to 0-based (C++) indexing
+    ordinal_outcome <- as.integer(y - 1)
+    n_levels <- max(y) # Number of ordinal categories
+
+    if (n_levels < 2) {
+        stop("Ordinal outcome must have at least 2 categories")
+    }
+
+    if (is.null(variable_weights)) {
+        variable_weights <- rep(1.0, n_features)
+    }
+
+    if (is.null(feature_types)) {
+        feature_types <- rep(0L, n_features)
+    }
+
+    if (!is.null(seed)) {
+        set.seed(seed)
+    }
+
+    keep_idx <- seq((n_burnin + 1), n_samples_mcmc, by = n_thin)
+    n_keep <- length(keep_idx)
+
+    forest_pred_train <- matrix(0, n_samples, n_keep)
+    if (has_test) {
+        n_samples_test <- nrow(X_test)
+        forest_pred_test <- matrix(0, n_samples_test, n_keep)
+    }
+    gamma_samples <- matrix(0, n_levels - 1, n_keep)
+    latent_samples <- matrix(0, n_samples, n_keep)
+
+    # Initialize model data structures
+    dataX <- stochtree::createForestDataset(X)
+    if (has_test) {
+        dataXtest <- stochtree::createForestDataset(X_test)
+    }
+    outcome_data <- stochtree::createOutcome(as.numeric(ordinal_outcome))
+    active_forest <- stochtree::createForest(as.integer(n_trees), 1L, TRUE, FALSE) # Use constant leaves
+    active_forest$set_root_leaves(0.0)
+    split_prior <- stochtree:::tree_prior_cpp(alpha_bart, beta_bart, min_samples_in_leaf, max_depth)
+    forest_samples <- stochtree::createForestSamples(as.integer(n_trees), 1L, TRUE, FALSE) # Use constant leaves
+    forest_tracker <- stochtree:::forest_tracker_cpp(
+        dataX$data_ptr,
+        as.integer(feature_types),
+        as.integer(n_trees),
+        as.integer(n_samples)
+    )
+    stochtree:::ordinal_aux_data_initialize_cpp(forest_tracker, as.integer(n_samples), as.integer(n_levels))
+
+    # Initialize gamma parameters to zero (slot 2)
+    initial_gamma <- rep(0.0, n_levels - 1)
+    for (i in seq_along(initial_gamma)) {
+        stochtree:::ordinal_aux_data_set_cpp(forest_tracker, 2, i - 1, initial_gamma[i])
+    }
+    stochtree:::ordinal_aux_data_update_cumsum_exp_cpp(forest_tracker)
+
+    # Initialize forest predictions slot to zero (slot 1)
+    for (i in 1:n_samples) {
+        stochtree:::ordinal_aux_data_set_cpp(forest_tracker, 1, i - 1, 0.0)
+    }
+
+    ordinal_sampler <- stochtree:::ordinal_sampler_cpp()
+    rng <- stochtree::createCppRNG(if (is.null(seed)) sample.int(.Machine$integer.max, 1) else seed)
+
+    # Set up sweep indices for tree updates (sample all trees each iteration)
+    sweep_indices <- as.integer(seq(0, n_trees - 1))
+
+    sample_counter <- 0
+    for (i in 1:n_samples_mcmc) {
+        keep_sample <- i %in% keep_idx
+        if (keep_sample) {
sample_counter <- sample_counter + 1 + } + + # 1. Sample forest using MCMC + stochtree:::sample_mcmc_one_iteration_cpp( + dataX$data_ptr, outcome_data$data_ptr, forest_samples$forest_container_ptr, + active_forest$forest_ptr, forest_tracker, split_prior, rng$rng_ptr, + sweep_indices, as.integer(feature_types), as.integer(cutpoint_grid_size), + scale_leaf, variable_weights, alpha_gamma, beta_gamma, 1.0, 4L, keep_sample + ) + + # Set auxiliary data slot 1 to current forest predictions = lambda_hat = sum of all the tree predictions + # This is needed for updating gamma parameters, latent z_i's + forest_pred_current <- active_forest$predict(dataX) + for (j in 1:n_samples) { + stochtree:::ordinal_aux_data_set_cpp(forest_tracker, 1, j - 1, forest_pred_current[j]) + } + + # 2. Sample latent z_i's using truncated exponential + stochtree:::ordinal_sampler_update_latent_variables_cpp( + ordinal_sampler, dataX$data_ptr, outcome_data$data_ptr, forest_tracker, rng$rng_ptr + ) + + # 3. Sample gamma parameters + stochtree:::ordinal_sampler_update_gamma_params_cpp( + ordinal_sampler, dataX$data_ptr, outcome_data$data_ptr, forest_tracker, + alpha_gamma, beta_gamma, gamma_0, rng$rng_ptr + ) + + # 4. Update cumulative sum of exp(gamma) values + stochtree:::ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, forest_tracker) + + if (keep_sample) { + forest_pred_train[, sample_counter] <- active_forest$predict(dataX) + if (has_test) { + forest_pred_test[, sample_counter] <- active_forest$predict(dataXtest) + } + gamma_current <- stochtree:::ordinal_aux_data_get_vector_cpp(forest_tracker, 2) + gamma_samples[, sample_counter] <- gamma_current + latent_current <- stochtree:::ordinal_aux_data_get_vector_cpp(forest_tracker, 0) + latent_samples[, sample_counter] <- latent_current + } + } + + result <- list( + forest_predictions_train = forest_pred_train, + forest_predictions_test = if (has_test) forest_pred_test else NULL, + gamma_samples = gamma_samples, + latent_samples = latent_samples, + scale_leaf = scale_leaf, + ordinal_outcome = ordinal_outcome, + n_trees = n_trees, + n_keep = n_keep + ) + + class(result) <- "cloglog_ordinal_bart" + return(result) +} diff --git a/R/cpp11.R b/R/cpp11.R index d77c7472..64db4be1 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -624,12 +624,12 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums) .Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums) } -sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample, num_threads) { - invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample, num_threads)) +sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample) { + invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, 
sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample)) } -sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_threads) { - invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_threads)) +sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) { + invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest)) } sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) { @@ -692,6 +692,46 @@ sample_without_replacement_integer_cpp <- function(population_vector, sampling_p .Call(`_stochtree_sample_without_replacement_integer_cpp`, population_vector, sampling_probs, sample_size) } +ordinal_aux_data_initialize_cpp <- function(tracker_ptr, num_observations, n_levels) { + invisible(.Call(`_stochtree_ordinal_aux_data_initialize_cpp`, tracker_ptr, num_observations, n_levels)) +} + +ordinal_aux_data_get_cpp <- function(tracker_ptr, type_idx, obs_idx) { + .Call(`_stochtree_ordinal_aux_data_get_cpp`, tracker_ptr, type_idx, obs_idx) +} + +ordinal_aux_data_set_cpp <- function(tracker_ptr, type_idx, obs_idx, value) { + invisible(.Call(`_stochtree_ordinal_aux_data_set_cpp`, tracker_ptr, type_idx, obs_idx, value)) +} + +ordinal_aux_data_get_vector_cpp <- function(tracker_ptr, type_idx) { + .Call(`_stochtree_ordinal_aux_data_get_vector_cpp`, tracker_ptr, type_idx) +} + +ordinal_aux_data_set_vector_cpp <- function(tracker_ptr, type_idx, values) { + invisible(.Call(`_stochtree_ordinal_aux_data_set_vector_cpp`, tracker_ptr, type_idx, values)) +} + +ordinal_aux_data_update_cumsum_exp_cpp <- function(tracker_ptr) { + invisible(.Call(`_stochtree_ordinal_aux_data_update_cumsum_exp_cpp`, tracker_ptr)) +} + +ordinal_sampler_cpp <- function() { + .Call(`_stochtree_ordinal_sampler_cpp`) +} + +ordinal_sampler_update_latent_variables_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, rng_ptr) { + invisible(.Call(`_stochtree_ordinal_sampler_update_latent_variables_cpp`, sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, rng_ptr)) +} + +ordinal_sampler_update_gamma_params_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr) { + invisible(.Call(`_stochtree_ordinal_sampler_update_gamma_params_cpp`, sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr)) +} + +ordinal_sampler_update_cumsum_exp_cpp <- function(sampler_ptr, tracker_ptr) { + invisible(.Call(`_stochtree_ordinal_sampler_update_cumsum_exp_cpp`, sampler_ptr, tracker_ptr)) +} + init_json_cpp <- function() { 
.Call(`_stochtree_init_json_cpp`) } diff --git a/include/stochtree/category_tracker.h b/include/stochtree/category_tracker.h index 2ce44635..e5817419 100644 --- a/include/stochtree/category_tracker.h +++ b/include/stochtree/category_tracker.h @@ -29,8 +29,12 @@ #include #include +#include #include #include +#include +#include +#include #include namespace StochTree { diff --git a/include/stochtree/common.h b/include/stochtree/common.h index cd57eea2..c7aab3df 100644 --- a/include/stochtree/common.h +++ b/include/stochtree/common.h @@ -8,18 +8,22 @@ #include #include +#include #include #include #include #include #include +#include #include #include #include +#include #include #include #include #include +#include #include #include diff --git a/include/stochtree/container.h b/include/stochtree/container.h index 4b75ef2f..bb0e7849 100644 --- a/include/stochtree/container.h +++ b/include/stochtree/container.h @@ -11,7 +11,12 @@ #include #include +#include +#include #include +#include +#include +#include namespace StochTree { diff --git a/include/stochtree/cutpoint_candidates.h b/include/stochtree/cutpoint_candidates.h index 76f1df4c..8c19013a 100644 --- a/include/stochtree/cutpoint_candidates.h +++ b/include/stochtree/cutpoint_candidates.h @@ -42,6 +42,8 @@ #include #include +#include + namespace StochTree { /*! \brief Computing and tracking cutpoints available for a given feature at a given node diff --git a/include/stochtree/data.h b/include/stochtree/data.h index a6061f4b..df232fb3 100644 --- a/include/stochtree/data.h +++ b/include/stochtree/data.h @@ -9,6 +9,7 @@ #include #include #include +#include namespace StochTree { diff --git a/include/stochtree/ensemble.h b/include/stochtree/ensemble.h index 4f6ddf42..4624b5a4 100644 --- a/include/stochtree/ensemble.h +++ b/include/stochtree/ensemble.h @@ -14,6 +14,12 @@ #include #include +#include +#include +#include +#include +#include + using json = nlohmann::json; namespace StochTree { diff --git a/include/stochtree/io.h b/include/stochtree/io.h index 55963946..3bc277fb 100644 --- a/include/stochtree/io.h +++ b/include/stochtree/io.h @@ -28,10 +28,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index a563cfe1..6adf9c23 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -16,318 +16,323 @@ #include #include #include +#include #include +#include +#include #include +#include +#include #include namespace StochTree { -/*! +/*! * \defgroup leaf_model_group Leaf Model API - * + * * \brief Classes / functions for implementing leaf models. - * - * Stochastic tree algorithms are all essentially hierarchical - * models with an adaptive group structure defined by an ensemble - * of decision trees. Each novel model is governed by - * + * + * Stochastic tree algorithms are all essentially hierarchical + * models with an adaptive group structure defined by an ensemble + * of decision trees. Each novel model is governed by + * * - A `LeafModel` class, defining the integrated likelihood and posterior, conditional on a particular tree structure * - A `SuffStat` class that tracks and accumulates sufficient statistics necessary for a `LeafModel` - * - * To provide a thorough overview of this interface (and, importantly, how to extend it), we must introduce some mathematical notation. 
+ * + * To provide a thorough overview of this interface (and, importantly, how to extend it), we must introduce some mathematical notation. * Any forest-based regression model involves an outcome, which we'll call \f$y\f$, and features (or "covariates"), which we'll call \f$X\f$. - * Our goal is to predict \f$y\f$ as a function of \f$X\f$, which we'll call \f$f(X)\f$. - * - * NOTE: if we have a more complicated, but still additive, model, such as \f$y = X\beta + f(X)\f$, then we can just model + * Our goal is to predict \f$y\f$ as a function of \f$X\f$, which we'll call \f$f(X)\f$. + * + * NOTE: if we have a more complicated, but still additive, model, such as \f$y = X\beta + f(X)\f$, then we can just model * \f$y - X\beta = f(X)\f$, treating the residual \f$y - X\beta\f$ as the outcome data, and we are back to the general setting above. - * - * Now, since \f$f(X)\f$ is an additive tree ensemble, we can think of it as the sum of \f$b\f$ separate decision tree functions, + * + * Now, since \f$f(X)\f$ is an additive tree ensemble, we can think of it as the sum of \f$b\f$ separate decision tree functions, * where \f$b\f$ is the number of trees in an ensemble, so that - * + * * \f[ * f(X) = f_1(X) + \dots + f_b(X) * \f] - * - * and each decision tree function \f$f_j\f$ has the property that features \f$X\f$ are used to determine which leaf node an observation - * falls into, and then the parameters attached to that leaf node are used to compute \f$f_j(X)\f$. The exact mechanics of this process + * + * and each decision tree function \f$f_j\f$ has the property that features \f$X\f$ are used to determine which leaf node an observation + * falls into, and then the parameters attached to that leaf node are used to compute \f$f_j(X)\f$. The exact mechanics of this process * are model-dependent, so now we introduce the "leaf node" models that `stochtree` supports. * * \section gaussian_constant_leaf_model Gaussian Constant Leaf Model - * + * * The most standard and common tree ensemble is a sum of "constant leaf" trees, in which a leaf node's parameter uniquely determines the prediction - * for all observations that fall into that leaf. For example, if leaf 2 for a tree is reached by the conditions that \f$X_1 < 0.4 \; \& \; X_2 > 0.6\f$, then - * every observation whose first feature is less than 0.4 and whose second feature is greater than 0.6 will receive the same prediction. Mathematically, + * for all observations that fall into that leaf. For example, if leaf 2 for a tree is reached by the conditions that \f$X_1 < 0.4 \; \& \; X_2 > 0.6\f$, then + * every observation whose first feature is less than 0.4 and whose second feature is greater than 0.6 will receive the same prediction. Mathematically, * for an observation \f$i\f$ this looks like - * + * * \f[ * f_j(X_i) = \sum_{\ell \in L} \mathbb{1}(X_i \in \ell) \mu_{\ell} * \f] - * + * * where \f$L\f$ denotes the indices of every leaf node, \f$\mu_{\ell}\f$ is the parameter attached to leaf node \f$\ell\f$, and \f$\mathbb{1}(X \in \ell)\f$ * checks whether \f$X_i\f$ falls into leaf node \f$\ell\f$. - * + * * The way that we make such a model "stochastic" is by attaching to the leaf node parameters \f$\mu_{\ell}\f$ a "prior" distribution. - * This leaf model corresponds to the "classic" BART model of Chipman et al (2010) - * as well as its "XBART" extension (He and Hahn (2023)). + * This leaf model corresponds to the "classic" BART model of Chipman et al (2010) + * as well as its "XBART" extension (He and Hahn (2023)). 
* We assign each leaf node parameter a prior - * + * * \f[ * \mu \sim N\left(0, \tau\right) * \f] - * - * Assuming a homoskedastic Gaussian outcome likelihood (i.e. \f$y_i \sim N\left(f(X_i),\sigma^2\right)\f$), - * the log marginal likelihood in this model, for the outcome data in node \f$\ell\f$ of tree \f$j\f$ is given by - * + * + * Assuming a homoskedastic Gaussian outcome likelihood (i.e. \f$y_i \sim N\left(f(X_i),\sigma^2\right)\f$), + * the log marginal likelihood in this model, for the outcome data in node \f$\ell\f$ of tree \f$j\f$ is given by + * * \f[ * L(y) = -\frac{n_{\ell}}{2}\log(2\pi) - n_{\ell}\log(\sigma) + \frac{1}{2} \log\left(\frac{\sigma^2}{n_{\ell} \tau + \sigma^2}\right) - \frac{s_{yy,\ell}}{2\sigma^2} + \frac{\tau s_{y,\ell}^2}{2\sigma^2(n_{\ell} \tau + \sigma^2)} * \f] - * + * * where - * + * * \f[ * n_{\ell} = \sum_{i : X_i \in \ell} 1 * \f] - * + * * \f[ * s_{y,\ell} = \sum_{i : X_i \in \ell} r_i * \f] - * + * * \f[ * s_{yy,\ell} = \sum_{i : X_i \in \ell} r_i^2 * \f] - * + * * \f[ * r_i = y_i - \sum_{k \neq j} f_k(X_i) * \f] * - * In words, this model depends on the data for a given leaf node only through three sufficient statistics, \f$n_{\ell}\f$, \f$s_{y,\ell}\f$, and \f$s_{yy,\ell}\f$, - * and it only depends on the other trees in the ensemble through the "partial residual" \f$r_i\f$. The posterior distribution for + * In words, this model depends on the data for a given leaf node only through three sufficient statistics, \f$n_{\ell}\f$, \f$s_{y,\ell}\f$, and \f$s_{yy,\ell}\f$, + * and it only depends on the other trees in the ensemble through the "partial residual" \f$r_i\f$. The posterior distribution for * node \f$\ell\f$'s leaf parameter is similarly defined as: - * + * * \f[ * \mu_{\ell} \mid - \sim N\left(\frac{\tau s_{y,\ell}}{n_{\ell} \tau + \sigma^2}, \frac{\tau \sigma^2}{n_{\ell} \tau + \sigma^2}\right) * \f] - * - * Now, consider the possibility that each observation carries a unique weight \f$w_i\f$. These could be "case weights" in a survey context or + * + * Now, consider the possibility that each observation carries a unique weight \f$w_i\f$. These could be "case weights" in a survey context or * individual-level variances ("heteroskedasticity"). These case weights transform the outcome distribution (and associated likelihood) to - * + * * \f[ - * y_i \mid - \sim N\left(\mu(X_i), \frac{\sigma^2}{w_i}\right) + * y_i \mid - \sim N\left(\mu(X_i), \frac{\sigma^2}{w_i}\right) * \f] - * - * This gives a modified log marginal likelihood of - * + * + * This gives a modified log marginal likelihood of + * * \f[ * L(y) = -\frac{n_{\ell}}{2}\log(2\pi) - \frac{1}{2} \sum_{i : X_i \in \ell} \log\left(\frac{\sigma^2}{w_i}\right) + \frac{1}{2} \log\left(\frac{\sigma^2}{s_{w,\ell} \tau + \sigma^2}\right) - \frac{s_{wyy,\ell}}{2\sigma^2} + \frac{\tau s_{wy,\ell}^2}{2\sigma^2(s_{w,\ell} \tau + \sigma^2)} * \f] - * + * * where - * + * * \f[ * s_{w,\ell} = \sum_{i : X_i \in \ell} w_i * \f] - * + * * \f[ * s_{wy,\ell} = \sum_{i : X_i \in \ell} w_i r_i * \f] - * + * * \f[ * s_{wyy,\ell} = \sum_{i : X_i \in \ell} w_i r_i^2 * \f] - * - * Finally, note that when we consider splitting leaf \f$\ell\f$ into new left and right leaves, or pruning two nodes into a single leaf node, - * we compute the log marginal likelihood of the combined data and the log marginal likelihoods of the left and right leaves and compare these three values. 
- *
- * The terms \f$\frac{n_{\ell}}{2}\log(2\pi)\f$, \f$\sum_{i : X_i \in \ell} \log\left(\frac{\sigma^2}{w_i}\right)\f$, and \f$\frac{s_{wyy,\ell}}{2\sigma^2}\f$
- * are such that their left and right node values will always sum to the respective value in the combined log marginal likelihood, so they can be ignored
+ *
+ * Finally, note that when we consider splitting leaf \f$\ell\f$ into new left and right leaves, or pruning two nodes into a single leaf node,
+ * we compute the log marginal likelihood of the combined data and the log marginal likelihoods of the left and right leaves and compare these three values.
+ *
+ * The terms \f$\frac{n_{\ell}}{2}\log(2\pi)\f$, \f$\sum_{i : X_i \in \ell} \log\left(\frac{\sigma^2}{w_i}\right)\f$, and \f$\frac{s_{wyy,\ell}}{2\sigma^2}\f$
+ * are such that their left and right node values will always sum to the respective value in the combined log marginal likelihood, so they can be ignored
 * when evaluating splits or prunes and thus the reduced log marginal likelihood is
- *
+ *
 * \f[
 * L(y) \propto \frac{1}{2} \log\left(\frac{\sigma^2}{s_{w,\ell} \tau + \sigma^2}\right) + \frac{\tau s_{wy,\ell}^2}{2\sigma^2(s_{w,\ell} \tau + \sigma^2)}
 * \f]
- *
+ *
 * So the \ref StochTree::GaussianConstantSuffStat "GaussianConstantSuffStat" class tracks a generalized version of these three statistics
 * (which allows for each observation to have a weight \f$w_i \neq 1\f$):
- *
+ *
 * - \f$n_{\ell}\f$: `data_size_t n`
 * - \f$s_{w,\ell}\f$: `double sum_w`
 * - \f$s_{wy,\ell}\f$: `double sum_yw`
- *
- * And these values are used by the \ref StochTree::GaussianConstantLeafModel "GaussianConstantLeafModel" class in the
- * \ref StochTree::GaussianConstantLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood",
- * \ref StochTree::GaussianConstantLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood",
- * \ref StochTree::GaussianConstantLeafModel::PosteriorParameterMean "PosteriorParameterMean", and
- * \ref StochTree::GaussianConstantLeafModel::PosteriorParameterVariance "PosteriorParameterVariance" methods.
+ *
+ * And these values are used by the \ref StochTree::GaussianConstantLeafModel "GaussianConstantLeafModel" class in the
+ * \ref StochTree::GaussianConstantLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood",
+ * \ref StochTree::GaussianConstantLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood",
+ * \ref StochTree::GaussianConstantLeafModel::PosteriorParameterMean "PosteriorParameterMean", and
+ * \ref StochTree::GaussianConstantLeafModel::PosteriorParameterVariance "PosteriorParameterVariance" methods.
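 *
 * As a usage sketch (illustrative only; `tau` and `global_variance` are assumed to be in scope here and the
 * accumulation loop is elided), a proposed split is scored by accumulating child sufficient statistics and
 * comparing log marginal likelihoods:
 *
 * \code{.cpp}
 * GaussianConstantSuffStat node_stat, left_stat, right_stat;  // zero-initialized
 * // ... IncrementSuffStat() is called for each observation in the node, routing it
 * // into left_stat or right_stat according to the proposed split rule ...
 * GaussianConstantLeafModel model(tau);
 * double split_log_ml = model.SplitLogMarginalLikelihood(left_stat, right_stat, global_variance);
 * double no_split_log_ml = model.NoSplitLogMarginalLikelihood(node_stat, global_variance);
 * // The difference split_log_ml - no_split_log_ml drives the accept/reject decision for the split
 * \endcode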
* To give one example, below is the implementation of \ref StochTree::GaussianConstantLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood": - * + * * \code{.cpp} * double left_log_ml = ( * -0.5*std::log(1 + tau_*(left_stat.sum_w/global_variance)) + ((tau_*left_stat.sum_yw*left_stat.sum_yw)/(2.0*global_variance*(tau_*left_stat.sum_w + global_variance))) * ); - * + * * double right_log_ml = ( * -0.5*std::log(1 + tau_*(right_stat.sum_w/global_variance)) + ((tau_*right_stat.sum_yw*right_stat.sum_yw)/(2.0*global_variance*(tau_*right_stat.sum_w + global_variance))) * ); - * + * * return left_log_ml + right_log_ml; - * \endcode - * + * \endcode + * * \section gaussian_multivariate_regression_leaf_model Gaussian Multivariate Regression Leaf Model - * - * In this model, the tree defines a "partitioned linear model" in which leaf node parameters define regression weights + * + * In this model, the tree defines a "partitioned linear model" in which leaf node parameters define regression weights * that are multiplied by a "basis" \f$\Omega\f$ to determine the prediction for an observation. - * + * * \f[ * f_j(X_i) = \sum_{\ell \in L} \mathbb{1}(X_i \in \ell) \Omega_i \vec{\beta_{\ell}} * \f] - * + * * and we assign \f$\beta_{\ell}\f$ a prior of - * + * * \f[ * \vec{\beta_{\ell}} \sim N\left(\vec{\beta_0}, \Sigma_0\right) * \f] - * + * * where \f$\vec{\beta_0}\f$ is typically a vector of zeros. The outcome likelihood is still - * + * * \f[ * y_i \sim N\left(f(X_i), \sigma^2\right) * \f] - * + * * This gives a reduced log integrated likelihood of - * + * * \f[ * L(y) \propto - \frac{1}{2} \log\left(\textrm{det}\left(I_p + \frac{\Sigma_0\Omega'\Omega}{\sigma^2}\right)\right) + \frac{1}{2}\frac{y'\Omega}{\sigma^2}\left(\Sigma_0^{-1} + \frac{\Omega'\Omega}{\sigma^2}\right)^{-1}\frac{\Omega'y}{\sigma^2} * \f] - * - * where \f$\Omega\f$ is a matrix of bases for every observation in leaf \f$\ell\f$ and \f$p\f$ is the dimension of \f$\Omega\f$. The posterior for \f$\vec{\beta_{\ell}}\f$ is - * + * + * where \f$\Omega\f$ is a matrix of bases for every observation in leaf \f$\ell\f$ and \f$p\f$ is the dimension of \f$\Omega\f$. The posterior for \f$\vec{\beta_{\ell}}\f$ is + * * \f[ * \vec{\beta_{\ell}} \sim N\left(\left(\Sigma_0^{-1} + \frac{\Omega'\Omega}{\sigma^2}\right)^{-1}\left(\frac{\Omega'y}{\sigma^2}\right),\left(\Sigma_0^{-1} + \frac{\Omega'\Omega}{\sigma^2}\right)^{-1}\right) * \f] - * + * * This is an extension of the single-tree model of Chipman et al (2002), with: - * + * * - Support for using a separate basis for leaf model than the partitioning (i.e. tree) model (i.e. 
\f$X \neq \Omega\f$) * - Support for multiple trees and sampling via grow-from-root (GFR) or MCMC - * + * * We can also enable heteroskedasticity by defining a (diagonal) covariance matrix for the outcome likelihood - * + * * \f[ * \Sigma_y = \text{diag}\left(\sigma^2 / w_1,\sigma^2 / w_2,\dots,\sigma^2 / w_n\right) * \f] - * + * * This updates the reduced log integrated likelihood to - * + * * \f[ * L(y) \propto - \frac{1}{2} \log\left(\textrm{det}\left(I_p + \Sigma_{0}\Omega'\Sigma_y^{-1}\Omega\right)\right) + \frac{1}{2}y'\Sigma_{y}^{-1}\Omega\left(\Sigma_{0}^{-1} + \Omega'\Sigma_{y}^{-1}\Omega\right)^{-1}\Omega'\Sigma_{y}^{-1}y * \f] - * - * and a posterior for \f$\vec{\beta_{\ell}}\f$ of - * + * + * and a posterior for \f$\vec{\beta_{\ell}}\f$ of + * * \f[ * \vec{\beta_{\ell}} \sim N\left(\left(\Sigma_{0}^{-1} + \Omega'\Sigma_{y}^{-1}\Omega\right)^{-1}\left(\Omega'\Sigma_{y}^{-1}y\right),\left(\Sigma_{0}^{-1} + \Omega'\Sigma_{y}^{-1}\Omega\right)^{-1}\right) * \f] - * + * * \section gaussian_univariate_regression_leaf_model Gaussian Univariate Regression Leaf Model - * + * * This specializes the Gaussian Multivariate Regression Leaf Model for a univariate leaf basis, which allows for several computational speedups (replacing generalized matrix operations with simple summation or sum-product operations). - * We simplify \f$\Omega\f$ to \f$\omega\f$, a univariate basis for every observation, so that \f$\Omega'\Omega = \sum_{i:i \in \ell}\omega_i^2\f$ and \f$\Omega'y = \sum_{i:i \in \ell}\omega_ir_i\f$. Similarly, the prior for the leaf - * parameter becomes univariate normal as in \ref gaussian_constant_leaf_model: - * + * We simplify \f$\Omega\f$ to \f$\omega\f$, a univariate basis for every observation, so that \f$\Omega'\Omega = \sum_{i:i \in \ell}\omega_i^2\f$ and \f$\Omega'y = \sum_{i:i \in \ell}\omega_ir_i\f$. Similarly, the prior for the leaf + * parameter becomes univariate normal as in \ref gaussian_constant_leaf_model: + * * \f[ * \beta \sim N\left(0, \tau\right) * \f] - * - * Allowing for case / variance weights \f$w_i\f$ as above, we derive a reduced log marginal likelihood of - * + * + * Allowing for case / variance weights \f$w_i\f$ as above, we derive a reduced log marginal likelihood of + * * \f[ * L(y) \propto \frac{1}{2} \log\left(\frac{\sigma^2}{s_{wxx,\ell} \tau + \sigma^2}\right) + \frac{\tau s_{wyx,\ell}^2}{2\sigma^2(s_{wxx,\ell} \tau + \sigma^2)} * \f] - * + * * where - * + * * \f[ * s_{wxx,\ell} = \sum_{i : X_i \in \ell} w_i \omega_i \omega_i * \f] - * + * * \f[ * s_{wyx,\ell} = \sum_{i : X_i \in \ell} w_i r_i \omega_i * \f] - * - * and a posterior of - * + * + * and a posterior of + * * \f[ * \beta_{\ell} \mid - \sim N\left(\frac{\tau s_{wyx,\ell}}{s_{wxx,\ell} \tau + \sigma^2}, \frac{\tau \sigma^2}{s_{wxx,\ell} \tau + \sigma^2}\right) * \f] - * + * * \section inverse_gamma_leaf_model Inverse Gamma Leaf Model - * - * Each of the above models is a variation on a theme: a conjugate, partitioned Gaussian leaf model. + * + * Each of the above models is a variation on a theme: a conjugate, partitioned Gaussian leaf model. 
 * The inverse gamma leaf model allows for forest-based heteroskedasticity modeling using an inverse gamma prior on the exponentiated leaf parameter, as discussed in Murray (2021).
 * Define a variance function based on an ensemble of \f$b\f$ trees as
- *
+ *
 * \f[
 * \sigma^2(X) = \exp\left(s_1(X) + \dots + s_b(X)\right)
 * \f]
- *
- * where each tree function \f$s_j(X)\f$ is defined as
- *
+ *
+ * where each tree function \f$s_j(X)\f$ is defined as
+ *
 * \f[
 * s_j(X_i) = \sum_{\ell \in L} \mathbb{1}(X_i \in \ell) \lambda_{\ell}
 * \f]
- *
+ *
 * We reparameterize \f$\lambda_{\ell} = \log(\mu_{\ell})\f$ and we place an inverse gamma prior on \f$\mu_{\ell}\f$
- *
+ *
 * \f[
 * \mu_{\ell} \sim \text{IG}\left(a, b\right)
 * \f]
- *
- * As noted in Murray (2021), this model no longer enables the "Bayesian backfitting" simplification
- * of conjugate Gaussian leaf models, in which sampling updates for a given tree only depend on other trees in the ensemble via their imprint on the partial residual
- * \f$r_i = y_i - \sum_{k \neq j} \mu_k(X_i)\f$.
+ *
+ * As noted in Murray (2021), this model no longer enables the "Bayesian backfitting" simplification
+ * of conjugate Gaussian leaf models, in which sampling updates for a given tree only depend on other trees in the ensemble via their imprint on the partial residual
+ * \f$r_i = y_i - \sum_{k \neq j} \mu_k(X_i)\f$.
 * However, this model is part of a broader class of models with convenient "blocked MCMC" sampling updates (another important example being multinomial classification).
- *
+ *
 * Under an outcome model
- *
+ *
 * \f[
 * y \sim N\left(f(X), \sigma_0^2 \sigma^2(X)\right)
 * \f]
- *
+ *
 * updates to \f$\mu_{\ell}\f$ for a given tree \f$j\f$ are based on a reduced log marginal likelihood of
- *
+ *
 * \f[
 * L(y) \propto a \log (b) - \log \Gamma (a) + \log \Gamma \left(a + \frac{n_{\ell}}{2}\right) - \left(a + \frac{n_{\ell}}{2}\right) \log\left(b + \frac{s_{\sigma,\ell}}{2\sigma_0^2}\right)
 * \f]
- *
+ *
 * where
- *
+ *
 * \f[
 * n_{\ell} = \sum_{i : X_i \in \ell} 1
 * \f]
- *
+ *
 * \f[
 * s_{\sigma,\ell} = \sum_{i : X_i \in \ell} \frac{(y_i - f(X_i))^2}{\prod_{k \neq j} \exp\left(s_k(X_i)\right)}
 * \f]
- *
- * and a posterior of
- *
+ *
+ * and a posterior of
+ *
 * \f[
 * \mu_{\ell} \mid - \sim \text{IG}\left( a + \frac{n_{\ell}}{2} , b + \frac{s_{\sigma,\ell}}{2\sigma_0^2} \right)
 * \f]
- *
+ *
 * Thus, as above, we implement a sufficient statistic class (\ref StochTree::LogLinearVarianceSuffStat "LogLinearVarianceSuffStat"), which tracks
- *
+ *
 * - \f$n_{\ell}\f$: `data_size_t n`
 * - \f$s_{\sigma,\ell}\f$: `double weighted_sum_ei`
- *
- * And these values are used by the \ref StochTree::LogLinearVarianceLeafModel "LogLinearVarianceLeafModel" class in the
- * \ref StochTree::LogLinearVarianceLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood",
- * \ref StochTree::LogLinearVarianceLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood",
- * \ref StochTree::LogLinearVarianceLeafModel::PosteriorParameterShape "PosteriorParameterShape", and
- * \ref StochTree::LogLinearVarianceLeafModel::PosteriorParameterScale "PosteriorParameterScale" methods.
+ * + * And these values are used by the \ref StochTree::LogLinearVarianceLeafModel "LogLinearVarianceLeafModel" class in the + * \ref StochTree::LogLinearVarianceLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood", + * \ref StochTree::LogLinearVarianceLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood", + * \ref StochTree::LogLinearVarianceLeafModel::PosteriorParameterShape "PosteriorParameterShape", and + * \ref StochTree::LogLinearVarianceLeafModel::PosteriorParameterScale "PosteriorParameterScale" methods. * To give one example, below is the implementation of \ref StochTree::LogLinearVarianceLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood": - * + * * \code{.cpp} * double prior_terms = a_ * std::log(b_) - boost::math::lgamma(a_); * double a_term = a_ + 0.5 * suff_stat.n; @@ -337,8 +342,8 @@ namespace StochTree { * double resid_term = a_term * log_b_term; * double log_ml = prior_terms + lgamma_a_term - resid_term; * return log_ml; - * \endcode - * + * \endcode + * * \{ */ @@ -350,9 +355,9 @@ namespace StochTree { * 5. `kCloglogOrdinal`: Every leaf node has a log-gamma prior and every leaf is constant. */ enum ModelType { - kConstantLeafGaussian, - kUnivariateRegressionLeafGaussian, - kMultivariateRegressionLeafGaussian, + kConstantLeafGaussian, + kUnivariateRegressionLeafGaussian, + kMultivariateRegressionLeafGaussian, kLogLinearVariance, kCloglogOrdinal }; @@ -373,7 +378,7 @@ class GaussianConstantSuffStat { } /*! * \brief Accumulate data from observation `row_idx` into the sufficient statistics - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state @@ -398,9 +403,9 @@ class GaussianConstantSuffStat { sum_w = 0.0; sum_yw = 0.0; } - /*! + /*! * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` - * + * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics */ void AddSuffStatInplace(GaussianConstantSuffStat& suff_stat) { @@ -410,7 +415,7 @@ class GaussianConstantSuffStat { } /*! * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -421,7 +426,7 @@ class GaussianConstantSuffStat { } /*! * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -432,7 +437,7 @@ class GaussianConstantSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than some threshold - * + * * \param threshold Value used to compute `n > threshold` */ bool SampleGreaterThan(data_size_t threshold) { @@ -440,7 +445,7 @@ class GaussianConstantSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold - * + * * \param threshold Value used to compute `n >= threshold` */ bool SampleGreaterThanEqual(data_size_t threshold) { @@ -459,14 +464,14 @@ class GaussianConstantLeafModel { public: /*! 
* \brief Construct a new GaussianConstantLeafModel object - * + * * \param tau Leaf node prior scale parameter */ GaussianConstantLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();} ~GaussianConstantLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. - * + * * \param left_stat Sufficient statistics of the left node formed by the proposed split * \param right_stat Sufficient statistics of the right node formed by the proposed split * \param global_variance Global error variance parameter @@ -474,28 +479,28 @@ class GaussianConstantLeafModel { double SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance); /*! * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior mean. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterMean(GaussianConstantSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior variance. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterVariance(GaussianConstantSuffStat& suff_stat, double global_variance); /*! * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state * \param residual Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` @@ -508,7 +513,7 @@ class GaussianConstantLeafModel { void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value); /*! * \brief Set a new value for the leaf node scale parameter - * + * * \param tau Leaf node prior scale parameter */ void SetScale(double tau) {tau_ = tau;} @@ -537,7 +542,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! * \brief Accumulate data from observation `row_idx` into the sufficient statistics - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state @@ -562,9 +567,9 @@ class GaussianUnivariateRegressionSuffStat { sum_xxw = 0.0; sum_yxw = 0.0; } - /*! + /*! 
* \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` - * + * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics */ void AddSuffStatInplace(GaussianUnivariateRegressionSuffStat& suff_stat) { @@ -574,7 +579,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -585,7 +590,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -596,7 +601,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than some threshold - * + * * \param threshold Value used to compute `n > threshold` */ bool SampleGreaterThan(data_size_t threshold) { @@ -604,7 +609,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold - * + * * \param threshold Value used to compute `n >= threshold` */ bool SampleGreaterThanEqual(data_size_t threshold) { @@ -625,7 +630,7 @@ class GaussianUnivariateRegressionLeafModel { ~GaussianUnivariateRegressionLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. - * + * * \param left_stat Sufficient statistics of the left node formed by the proposed split * \param right_stat Sufficient statistics of the right node formed by the proposed split * \param global_variance Global error variance parameter @@ -633,28 +638,28 @@ class GaussianUnivariateRegressionLeafModel { double SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& left_stat, GaussianUnivariateRegressionSuffStat& right_stat, double global_variance); /*! * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior mean. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior variance. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterVariance(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance); /*! 
* \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state * \param residual Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` @@ -681,7 +686,7 @@ class GaussianMultivariateRegressionSuffStat { Eigen::MatrixXd ytWX; /*! * \brief Construct a new GaussianMultivariateRegressionSuffStat object - * + * * \param basis_dim Size of the basis vector that defines the leaf regression */ GaussianMultivariateRegressionSuffStat(int basis_dim) { @@ -692,7 +697,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! * \brief Accumulate data from observation `row_idx` into the sufficient statistics - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state @@ -717,9 +722,9 @@ class GaussianMultivariateRegressionSuffStat { XtWX = Eigen::MatrixXd::Zero(p, p); ytWX = Eigen::MatrixXd::Zero(1, p); } - /*! + /*! * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` - * + * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics */ void AddSuffStatInplace(GaussianMultivariateRegressionSuffStat& suff_stat) { @@ -729,7 +734,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -740,7 +745,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -751,7 +756,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than some threshold - * + * * \param threshold Value used to compute `n > threshold` */ bool SampleGreaterThan(data_size_t threshold) { @@ -759,7 +764,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold - * + * * \param threshold Value used to compute `n >= threshold` */ bool SampleGreaterThanEqual(data_size_t threshold) { @@ -778,14 +783,14 @@ class GaussianMultivariateRegressionLeafModel { public: /*! 
* \brief Construct a new GaussianMultivariateRegressionLeafModel object - * + * * \param Sigma_0 Prior covariance, must have the same number of rows and columns as dimensions of the basis vector for the multivariate regression problem */ GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0; multivariate_normal_sampler_ = MultivariateNormalSampler();} ~GaussianMultivariateRegressionLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. - * + * * \param left_stat Sufficient statistics of the left node formed by the proposed split * \param right_stat Sufficient statistics of the right node formed by the proposed split * \param global_variance Global error variance parameter @@ -793,28 +798,28 @@ class GaussianMultivariateRegressionLeafModel { double SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& left_stat, GaussianMultivariateRegressionSuffStat& right_stat, double global_variance); /*! * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double NoSplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior mean. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior variance. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ Eigen::MatrixXd PosteriorParameterVariance(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state * \param residual Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` @@ -843,7 +848,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Accumulate data from observation `row_idx` into the sufficient statistics - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state @@ -863,7 +868,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` - * + * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics */ void AddSuffStatInplace(LogLinearVarianceSuffStat& suff_stat) { @@ -872,7 +877,7 @@ class LogLinearVarianceSuffStat { } /*! 
* \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -882,7 +887,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -892,7 +897,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than some threshold - * + * * \param threshold Value used to compute `n > threshold` */ bool SampleGreaterThan(data_size_t threshold) { @@ -900,7 +905,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold - * + * * \param threshold Value used to compute `n >= threshold` */ bool SampleGreaterThanEqual(data_size_t threshold) { @@ -921,7 +926,7 @@ class LogLinearVarianceLeafModel { ~LogLinearVarianceLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. - * + * * \param left_stat Sufficient statistics of the left node formed by the proposed split * \param right_stat Sufficient statistics of the right node formed by the proposed split * \param global_variance Global error variance parameter @@ -929,7 +934,7 @@ class LogLinearVarianceLeafModel { double SplitLogMarginalLikelihood(LogLinearVarianceSuffStat& left_stat, LogLinearVarianceSuffStat& right_stat, double global_variance); /*! * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ @@ -937,21 +942,21 @@ class LogLinearVarianceLeafModel { double SuffStatLogMarginalLikelihood(LogLinearVarianceSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior shape parameter. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterShape(LogLinearVarianceSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior scale parameter. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterScale(LogLinearVarianceSuffStat& suff_stat, double global_variance); /*! * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state * \param residual Data object containing the "full" residual net of all the model's mean terms @@ -971,13 +976,14 @@ class LogLinearVarianceLeafModel { GammaSampler gamma_sampler_; }; + /*! \brief Sufficient statistic and associated operations for complementary log-log ordinal BART model */ class CloglogOrdinalSuffStat { public: data_size_t n; double sum_Y_less_K; double other_sum; - + /*! 
 * \brief Construct a new CloglogOrdinalSuffStat object, setting all sufficient statistics to zero
 */
@@ -986,10 +992,10 @@ class CloglogOrdinalSuffStat {
     sum_Y_less_K = 0.0;
     other_sum = 0.0;
   }
-
+
   /*!
    * \brief Accumulate data from observation `row_idx` into the sufficient statistics
-   *
+   *
    * \param dataset Data object containing training data, including covariates
    * \param outcome Data object containing the original ordinal outcome values, which are used to compute sufficient statistics
    * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state
   * \param row_idx Index of the training data observation from which the sufficient statistics should be updated
   * \param tree_idx Index of the tree being updated in the course of this sufficient statistic update
   */
  void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
    n += 1;
-
+
    // Get ordinal outcome value for this observation (categories assumed coded 0, ..., K-1)
    int y = static_cast<int>(outcome(row_idx));
-
+
    // Get auxiliary data from tracker (assuming types: 0=latents Z, 1=forest predictions, 2=cutpoints gamma, 3=cumsum exp of gamma)
    double Z = tracker.GetOrdinalAuxData(0, row_idx); // latent variables Z
-    double lambda_minus = tracker.GetOrdinalAuxData(1, row_idx); // forest predictions excluding current tree
+    double lambda_minus = tracker.GetOrdinalAuxData(1, row_idx); // forest predictions excluding current tree

    // Get cutpoints gamma and cumulative sum of exp(gamma)
    const std::vector<double>& gamma = tracker.GetOrdinalAuxDataVector(2); // cutpoints gamma
    const std::vector<double>& seg = tracker.GetOrdinalAuxDataVector(3); // cumsum exp of gamma

    int K = static_cast<int>(gamma.size()) + 1; // Number of ordinal categories
-
+
    if (y == K - 1) {
      // Highest category: only the cumulative sum of exponentiated cutpoints contributes
      other_sum += std::exp(lambda_minus) * seg[y];
    } else {
      // Category below K: counts toward the "event" total, and the latent Z scales exp(gamma[y])
      // on top of the cumulative sum of exponentiated cutpoints below y
      sum_Y_less_K += 1;
      other_sum += std::exp(lambda_minus) * (Z * std::exp(gamma[y]) + seg[y]);
    }
  }
-
+
  /*!
   * \brief Reset all of the sufficient statistics to zero
   */
@@ -1028,10 +1034,21 @@ class CloglogOrdinalSuffStat {
     sum_Y_less_K = 0.0;
     other_sum = 0.0;
   }
-
+
+  /*!
+   * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat`
+   *
+   * \param suff_stat Sufficient statistic to be added to the current sufficient statistics
+   */
+  void AddSuffStatInplace(CloglogOrdinalSuffStat& suff_stat) {
+    n += suff_stat.n;
+    sum_Y_less_K += suff_stat.sum_Y_less_K;
+    other_sum += suff_stat.other_sum;
+  }
+
   /*!
    * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs`
-   *
+   *
    * \param lhs First sufficient statistic ("left hand side")
    * \param rhs Second sufficient statistic ("right hand side")
    */
@@ -1040,10 +1057,10 @@ class CloglogOrdinalSuffStat {
     sum_Y_less_K = lhs.sum_Y_less_K + rhs.sum_Y_less_K;
     other_sum = lhs.other_sum + rhs.other_sum;
   }
-
+
   /*!
    * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs`
-   *
+   *
    * \param lhs First sufficient statistic ("left hand side")
    * \param rhs Second sufficient statistic ("right hand side")
    */
@@ -1052,25 +1069,25 @@ class CloglogOrdinalSuffStat {
     sum_Y_less_K = lhs.sum_Y_less_K - rhs.sum_Y_less_K;
     other_sum = lhs.other_sum - rhs.other_sum;
   }
-
+
   /*!
    * \brief Check whether accumulated sample size, `n`, is greater than some threshold
-   *
+   *
    * \param threshold Value used to compute `n > threshold`
   */
  bool SampleGreaterThan(data_size_t threshold) {
    return n > threshold;
  }
-
+
  /*!
* \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold - * + * * \param threshold Value used to compute `n >= threshold` */ bool SampleGreaterThanEqual(data_size_t threshold) { return n >= threshold; } - + /*! * \brief Return the sample size accumulated by a sufficient stat object */ @@ -1084,8 +1101,8 @@ class CloglogOrdinalLeafModel { public: /*! * \brief Construct a new CloglogOrdinalLeafModel object - * - * \param a Shape parameter for log-gamma prior on leaf parameters + * + * \param a shape parameter for log-gamma prior on leaf parameters * \param b rate parameter for log-gamma prior on leaf parameters * Log-gamma density: f(x) = b^a / Gamma(a) * exp(a*x - b*exp(x)) * Relationship to tau (scale of leaf parameters): tau^2 = trigamma(a) @@ -1094,6 +1111,7 @@ class CloglogOrdinalLeafModel { a_ = a; b_ = b; gamma_sampler_ = GammaSampler(); + // slice_sampler_ = SliceSampler(); tau_ = std::sqrt(boost::math::trigamma(a_)); } ~CloglogOrdinalLeafModel() {} @@ -1139,9 +1157,39 @@ class CloglogOrdinalLeafModel { inline bool RequiresBasis() {return false;} + // /*! + // * \brief Update the scale parameter (tau_) using slice sampling + // * + // * \param lambda Vector of leaf parameter values from all trees + // * \param scale_sigma_lambda Prior scale parameter for scale parameter (tau_) of leaf parameters + // * \param gen Random number generator + // */ + // void UpdateScaleLambda(const std::vector& lambda, double scale_sigma_lambda, std::mt19937& gen) { + // double n = static_cast(lambda.size()); + // double sum_lambda = 0.0; + // double sum_exp_lambda = 0.0; + + // for (size_t i = 0; i < lambda.size(); i++) { + // sum_lambda += lambda[i]; + // sum_exp_lambda += std::exp(lambda[i]); + // } + + // // Create log-likelihood function + // ScaleLambdaLoglik loglik_func(n, sum_lambda, sum_exp_lambda, scale_sigma_lambda); + + // // Sample new scale parameter using slice sampling + // double current_tau = tau_; + // double w = 1.0; // Step size for slice sampler + // double lower = 1e-6; // Lower bound for tau + // double upper = std::numeric_limits::infinity(); // Upper bound + + // double new_tau = slice_sampler_.Sample(current_tau, &loglik_func, w, lower, upper, gen); + // tau_ = new_tau; + // } + /*! * \brief Convert tau_ (scale_lambda i.e. scale for leaf parameters) to alpha (shape) and beta (rate) parameters for the log-gamma prior - * + * * \param alpha Output: shape parameter for log-gamma prior * \param beta Output: rate parameter for log-gamma prior * \param tau Scale parameter (tau_) for leaf parameters @@ -1155,7 +1203,7 @@ class CloglogOrdinalLeafModel { /*! * \brief Convert alpha (shape) and beta (rate) parameters (for the log-gamma prior) back to tau_ (scale_lambda i.e. scale for leaf parameters) - * + * * \param alpha Shape parameter for log-gamma prior * \param beta Rate parameter for log-gamma prior * \return tau Scale parameter (tau_) for leaf parameters @@ -1169,9 +1217,9 @@ class CloglogOrdinalLeafModel { private: /*! * \brief Compute inverse trigamma function using Newton's method - * + * * Implementation adapted from limma package in R, originally by Gordon Smyth - * + * * \param x Input value for which to compute trigamma inverse * \return Value y such that trigamma(y) = x */ @@ -1198,34 +1246,34 @@ class CloglogOrdinalLeafModel { double a_; double b_; GammaSampler gamma_sampler_; + // SliceSampler slice_sampler_; double tau_; }; -/*! 
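/*!
 * \brief Illustrative sketch of the conjugate update implied by CloglogOrdinalSuffStat
 *
 * This block is expository, not part of the library: under the log-gamma prior
 * f(lambda) = b^a / Gamma(a) * exp(a*lambda - b*exp(lambda)) documented above, the statistics
 * `sum_Y_less_K` and `other_sum` act as conjugate shape and rate increments, so a leaf parameter
 * draw might look like the following (the function name and exact update are assumptions for
 * exposition, not the sampler's implementation):
 *
 * \code{.cpp}
 * #include <cmath>
 * #include <random>
 *
 * double DrawCloglogLeafParameter(double a, double b, double sum_Y_less_K,
 *                                 double other_sum, std::mt19937& gen) {
 *   // Posterior on exp(lambda) is Gamma(shape = a + sum_Y_less_K, rate = b + other_sum);
 *   // std::gamma_distribution is parameterized by (shape, scale), so scale = 1 / rate
 *   std::gamma_distribution<double> leaf_dist(a + sum_Y_less_K, 1.0 / (b + other_sum));
 *   return std::log(leaf_dist(gen));  // leaf parameter lambda lives on the log scale
 * }
 * \endcode
 */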
- * \brief Unifying layer for disparate sufficient statistic class types
- *
- * Joins together GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat,
- * GaussianMultivariateRegressionSuffStat, and LogLinearVarianceSuffStat
- * as a combined "variant" type. See the std::variant documentation
+/*! \brief Unifying layer for disparate sufficient statistic class types
+ *
+ * Joins together GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat,
+ * GaussianMultivariateRegressionSuffStat, LogLinearVarianceSuffStat, and CloglogOrdinalSuffStat
+ * as a combined "variant" type. See the std::variant documentation
 * for more detail.
 */
-using SuffStatVariant = std::variant<GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat, GaussianMultivariateRegressionSuffStat, LogLinearVarianceSuffStat>;
+using SuffStatVariant = std::variant<GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat, GaussianMultivariateRegressionSuffStat, LogLinearVarianceSuffStat, CloglogOrdinalSuffStat>;

/*!
 * \brief Unifying layer for disparate leaf model class types
- *
- * Joins together GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel,
- * GaussianMultivariateRegressionLeafModel, and LogLinearVarianceLeafModel
- * as a combined "variant" type. See the std::variant documentation
+ *
+ * Joins together GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel,
+ * GaussianMultivariateRegressionLeafModel, LogLinearVarianceLeafModel, and CloglogOrdinalLeafModel
+ * as a combined "variant" type. See the std::variant documentation
 * for more detail.
 */
-using LeafModelVariant = std::variant<GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, GaussianMultivariateRegressionLeafModel, LogLinearVarianceLeafModel>;
+using LeafModelVariant = std::variant<GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, GaussianMultivariateRegressionLeafModel, LogLinearVarianceLeafModel, CloglogOrdinalLeafModel>;
@@ -1241,7 +1289,7 @@ static inline LeafModelVariant createLeafModel(LeafModelConstructorArgs... leaf_
 /*!
  * \brief Factory function that creates a new `SuffStat` object for the specified model type
- *
+ *
  * \param model_type Enumeration storing the model type
  * \param basis_dim [Optional] dimension of the basis vector, only used if `model_type = kMultivariateRegressionLeafGaussian`
  */
@@ -1261,16 +1309,14 @@ static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_di
 /*!
* \brief Factory function that creates a new `LeafModel` object for the specified model type - * + * * \param model_type Enumeration storing the model type - * \param tau Value of the leaf node prior scale parameter, only used if `model_type = kConstantLeafGaussian` or `model_type = kUnivariateRegressionLeafGaussian` + * \param tau Value of the leaf node prior scale parameter, only used if `model_type = kConstantLeafGaussian`, `model_type = kUnivariateRegressionLeafGaussian` * \param Sigma0 Value of the leaf node prior covariance matrix, only used if `model_type = kMultivariateRegressionLeafGaussian` - * \param a Value of the leaf node inverse gamma prior shape parameter, only used if `model_type = kLogLinearVariance` - * \param b Value of the leaf node inverse gamma prior scale parameter, only used if `model_type = kLogLinearVariance` - * \param c Value of the leaf node log-gamma prior shape parameter, only used if `model_type = kCloglogOrdinal` - * \param d Value of the leaf node log-gamma prior rate parameter, only used if `model_type = kCloglogOrdinal` + * \param a Value of the leaf node inverse gamma prior shape parameter, only used if `model_type = kLogLinearVariance` (or value of the leaf node log-gamma prior shape parameter, only used if `model_type = kCloglogOrdinal`) + * \param b Value of the leaf node inverse gamma prior scale parameter, only used if `model_type = kLogLinearVariance` (or value of the leaf node log-gamma prior rate parameter, only used if `model_type = kCloglogOrdinal`) */ -static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b, double c, double d) { +static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b) { if (model_type == kConstantLeafGaussian) { return createLeafModel(tau); } else if (model_type == kUnivariateRegressionLeafGaussian) { @@ -1280,14 +1326,14 @@ static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau } else if (model_type == kLogLinearVariance) { return createLeafModel(a, b); } else { - return createLeafModel(c, d); + return createLeafModel(a, b); } } template static inline void AccumulateSuffStatProposed( - SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker, - ColumnVector& residual, double global_variance, TreeSplit& split, int tree_num, int leaf_num, int split_feature, int num_threads, + SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker, + ColumnVector& residual, double global_variance, TreeSplit& split, int tree_num, int leaf_num, int split_feature, int num_threads, SuffStatConstructorArgs&... 
suff_stat_args ) { // Determine the position of the node's indices in the forest tracking data structure @@ -1312,13 +1358,13 @@ static inline void AccumulateSuffStatProposed( std::vector thread_suff_stats_left; std::vector thread_suff_stats_right; for (int i = 0; i < num_threads; i++) { - thread_ranges[i] = std::make_pair(node_begin_index + i * chunk_size, + thread_ranges[i] = std::make_pair(node_begin_index + i * chunk_size, node_begin_index + (i + 1) * chunk_size); thread_suff_stats_node.emplace_back(suff_stat_args...); thread_suff_stats_left.emplace_back(suff_stat_args...); thread_suff_stats_right.emplace_back(suff_stat_args...); } - + // Accumulate sufficient statistics StochTree::ParallelFor(0, num_threads, num_threads, [&](int i) { int start_idx = thread_ranges[i].first; @@ -1356,7 +1402,7 @@ static inline void AccumulateSuffStatProposed( } template -static inline void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker, +static inline void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id) { // Acquire iterators auto left_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, left_node_id); @@ -1392,7 +1438,7 @@ static inline void AccumulateSingleNodeSuffStat(SuffStatType& node_suff_stat, Fo node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id); node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id); } - + // Accumulate sufficient statistics for (auto i = node_begin_iter; i != node_end_iter; i++) { auto idx = *i; @@ -1401,13 +1447,13 @@ static inline void AccumulateSingleNodeSuffStat(SuffStatType& node_suff_stat, Fo } template -static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, ForestTracker& tracker, CutpointGridContainer& cutpoint_grid_container, - ForestDataset& dataset, ColumnVector& residual, double global_variance, int tree_num, int node_id, +static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, ForestTracker& tracker, CutpointGridContainer& cutpoint_grid_container, + ForestDataset& dataset, ColumnVector& residual, double global_variance, int tree_num, int node_id, int feature_num, int cutpoint_num) { // Acquire iterators auto node_begin_iter = tracker.SortedNodeBeginIterator(node_id, feature_num); auto node_end_iter = tracker.SortedNodeEndIterator(node_id, feature_num); - + // Determine node start point data_size_t node_begin = tracker.SortedNodeBegin(node_id, feature_num); diff --git a/include/stochtree/log.h b/include/stochtree/log.h index 9f64c31b..3a4c5600 100644 --- a/include/stochtree/log.h +++ b/include/stochtree/log.h @@ -15,6 +15,8 @@ #include #include #include +#include +#include #include #include diff --git a/include/stochtree/meta.h b/include/stochtree/meta.h index d0aa4049..991c254f 100644 --- a/include/stochtree/meta.h +++ b/include/stochtree/meta.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include diff --git a/include/stochtree/ordinal_sampler.h b/include/stochtree/ordinal_sampler.h index ec148b5b..054a14c7 100644 --- a/include/stochtree/ordinal_sampler.h +++ b/include/stochtree/ordinal_sampler.h @@ -43,7 +43,6 @@ class OrdinalSampler { */ static double SampleTruncatedExponential(double lambda, 
std::mt19937& gen);
-
 /*!
  * \brief Update truncated exponential latent variables (Z)
  *
diff --git a/include/stochtree/partition_tracker.h b/include/stochtree/partition_tracker.h
index a2f5dd70..3f342f15 100644
--- a/include/stochtree/partition_tracker.h
+++ b/include/stochtree/partition_tracker.h
@@ -31,7 +31,11 @@
 #include
 #include
+#include
 #include
+#include
+#include
+#include
 #include

 namespace StochTree {
@@ -104,6 +108,7 @@ class ForestTracker {
   void SetOrdinalAuxData(int type_idx, data_size_t obs_idx, double value);
   std::vector<double>& GetOrdinalAuxDataVector(int type_idx);
+
  private:
   /*! \brief Mapper from observations to predicted values summed over every tree in a forest */
   std::vector<double> sum_predictions_;
@@ -132,7 +137,7 @@ class ForestTracker {
   void UpdateSampleTrackersInternal(TreeEnsemble& forest, Eigen::MatrixXd& covariates);
   void UpdateSampleTrackersResidualInternalBasis(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model);
   void UpdateSampleTrackersResidualInternalNoBasis(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model);
-
+
   /*!
    * \brief Track auxiliary data for cloglog ordinal bart models
    * Vector of vectors to store these auxiliary data
    * Each inner vector holds one type of data (order: Latent variable Z, Forest predictions, Cutpoints gamma, Cumsum exp of cutpoints)
    */
   std::vector<std::vector<double>> ordinal_aux_data_vec_;
@@ -146,6 +151,8 @@ class ForestTracker {
    * type_idx is the index of the type of auxiliary data (0: latent Z, 1: forest predictions, 2: cutpoints gamma, 3: cumsum exp of cutpoints)
    */
   void ResizeOrdinalAuxData(data_size_t num_observations, int n_levels);
+  // bool IsValidOrdinalType(int type_idx) const;
+  // bool IsValidOrdinalIndex(int type_idx, data_size_t obs_idx) const;
 };

 /*! \brief Class storing sample-prediction map for each tree in an ensemble */
@@ -456,7 +463,7 @@ class UnsortedNodeSampleTracker {
   /*! \brief Number of trees */
   int NumTrees() { return num_trees_; }

-  /*! \brief Return a pointer to the feature partition tracking tree i */
+  /*! \brief Return a pointer to the feature partition tracker for tree i */
   FeatureUnsortedPartition* GetFeaturePartition(int i) { return feature_partitions_[i].get(); }

  private:
@@ -637,24 +644,24 @@ class SortedNodeSampleTracker {
   }

   /*! \brief Partition a node based on a new split rule */
-  void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, TreeSplit& split, int num_threads = -1) {
-    StochTree::ParallelFor(0, num_features_, num_threads, [&](int i) {
+  void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, TreeSplit& split) {
+    for (int i = 0; i < num_features_; i++) {
       feature_partitions_[i]->SplitFeature(covariates, node_id, feature_split, split);
-    });
+    }
   }

   /*! \brief Partition a node based on a new split rule */
-  void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, double split_value, int num_threads = -1) {
-    StochTree::ParallelFor(0, num_features_, num_threads, [&](int i) {
+  void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, double split_value) {
+    for (int i = 0; i < num_features_; i++) {
       feature_partitions_[i]->SplitFeatureNumeric(covariates, node_id, feature_split, split_value);
-    });
+    }
   }

   /*!
\brief Partition a node based on a new split rule */ - void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, std::vector const& category_list, int num_threads = -1) { - StochTree::ParallelFor(0, num_features_, num_threads, [&](int i) { + void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, std::vector const& category_list) { + for (int i = 0; i < num_features_; i++) { feature_partitions_[i]->SplitFeatureCategorical(covariates, node_id, feature_split, category_list); - }); + } } /*! \brief First index of data points contained in node_id */ diff --git a/include/stochtree/random.h b/include/stochtree/random.h index 3d39b647..a841f396 100644 --- a/include/stochtree/random.h +++ b/include/stochtree/random.h @@ -5,6 +5,7 @@ #ifndef STOCHTREE_RANDOM_H_ #define STOCHTREE_RANDOM_H_ +#include #include #include #include diff --git a/include/stochtree/random_effects.h b/include/stochtree/random_effects.h index b322a560..701ebeaa 100644 --- a/include/stochtree/random_effects.h +++ b/include/stochtree/random_effects.h @@ -17,11 +17,14 @@ #include #include +#include #include #include #include #include +#include #include +#include #include namespace StochTree { diff --git a/include/stochtree/slice_sampler.h b/include/stochtree/slice_sampler.h new file mode 100644 index 00000000..07fe5a26 --- /dev/null +++ b/include/stochtree/slice_sampler.h @@ -0,0 +1,180 @@ +/*! + * Copyright (c) 2024 stochtree authors. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. + */ +#ifndef STOCHTREE_SLICE_SAMPLER_H_ +#define STOCHTREE_SLICE_SAMPLER_H_ + +#include +#include +#include +#include +#include + +#ifndef M_LN2 +#define M_LN2 0.6931471805599453 // ln(2) +#endif + +namespace StochTree { + +/*! + * \brief Abstract base class for log-likelihood functions used in slice sampling + */ +class LoglikFunction { + public: + virtual ~LoglikFunction() {} + + /*! + * \brief Evaluate the log-likelihood function at point x + * \param x Input value + * \return Log-likelihood value + */ + virtual double Evaluate(double x) = 0; +}; + +/*! + * \brief Log-likelihood function for scale_lambda parameter in ordinal models + */ +class ScaleLambdaLoglik : public LoglikFunction { + public: + /*! + * \brief Constructor for scale lambda log-likelihood + * \param n Number of observations (lambda values) + * \param sum_lambda Sum of all lambda values + * \param sum_exp_lambda Sum of exp(lambda) values + * \param scale Prior scale parameter for scale_lambda + */ + ScaleLambdaLoglik(double n, double sum_lambda, double sum_exp_lambda, double scale) + : n_(n), sum_lambda_(sum_lambda), sum_exp_lambda_(sum_exp_lambda), scale_(scale) {} + + /*! 
+ * \brief Evaluate log-likelihood of scale_lambda parameter + * \param sigma Input scale parameter value (scale_lambda) + * \return Log-likelihood value + */ + double Evaluate(double sigma) override { + if (sigma <= 0) return -std::numeric_limits::infinity(); + + // Convert scale_lambda to alpha and beta parameters + double alpha, beta; + ScaleLambdaToAlphaBeta(alpha, beta, sigma); + + // Log-likelihood contribution from lambda values (gamma prior) + double loglik = n_ * alpha * std::log(beta) + - n_ * boost::math::lgamma(alpha) + + alpha * sum_lambda_ + - beta * sum_exp_lambda_; + + // Add constants and prior terms + loglik += M_LN2 - 0.5 * std::log(2.0 * M_PI); // M_LN2 - LN_2_BY_PI approximation + + // Prior on scale_lambda (half-normal or similar) + double scale_ratio = sigma / scale_; + loglik -= 0.5 * scale_ratio * scale_ratio; + + return loglik; + } + + private: + double n_; + double sum_lambda_; + double sum_exp_lambda_; + double scale_; + + /*! + * \brief Convert scale_lambda to alpha and beta parameters for the gamma prior + */ + void ScaleLambdaToAlphaBeta(double& alpha, double& beta, const double sigma) { + double sigma_sq = sigma * sigma; + alpha = TrigammaInverse(sigma_sq); + beta = std::exp(boost::math::digamma(alpha)); + } + + /*! + * \brief Compute inverse trigamma function using Newton's method + */ + double TrigammaInverse(double x) { + if (x > 1E7) return 1.0 / std::sqrt(x); + if (x < 1E-6) return 1.0 / x; + + double y = 0.5 + 1.0 / x; + for (int i = 0; i < 50; i++) { + double tri = boost::math::trigamma(y); + double dif = tri * (1.0 - tri / x) / boost::math::polygamma(3, y); + y += dif; + if (-dif / y < 1E-8) break; + } + return y; + } +}; + +/*! + * \brief Slice sampler implementation + */ +class SliceSampler { + public: + SliceSampler() {} + ~SliceSampler() {} + + /*! 
+ * \brief Sample from a distribution using slice sampling + * \param x0 Initial value + * \param loglik_func Log-likelihood function + * \param w Step size for expanding interval + * \param lower Lower bound + * \param upper Upper bound + * \param gen Random number generator + * \return Sampled value + */ + double Sample(double x0, LoglikFunction* loglik_func, double w, + double lower, double upper, std::mt19937& gen) { + + std::uniform_real_distribution unif(0.0, 1.0); + std::exponential_distribution exp_dist(1.0); + + // Find the log density at the initial point + double gx0 = loglik_func->Evaluate(x0); + + // Determine the slice level, in log terms + double logy = gx0 - exp_dist(gen); + + // Find the initial interval to sample from + double u = w * unif(gen); + double L = x0 - u; + double R = x0 + (w - u); + + // Expand the interval until its ends are outside the slice + while (L > lower && loglik_func->Evaluate(L) > logy) { + L -= w; + } + + while (R < upper && loglik_func->Evaluate(R) > logy) { + R += w; + } + + // Shrink interval to bounds + if (L < lower) L = lower; + if (R > upper) R = upper; + + // Sample from the interval, shrinking it on each rejection + double x1; + do { + x1 = L + (R - L) * unif(gen); + double gx1 = loglik_func->Evaluate(x1); + + if (gx1 >= logy) break; + + if (x1 > x0) { + R = x1; + } else { + L = x1; + } + } while (true); + + return x1; + } +}; + +} // namespace StochTree + +#endif // STOCHTREE_SLICE_SAMPLER_H_ diff --git a/include/stochtree/tree.h b/include/stochtree/tree.h index 3810e3cb..85ce7191 100644 --- a/include/stochtree/tree.h +++ b/include/stochtree/tree.h @@ -13,6 +13,9 @@ #include #include +#include +#include +#include #include #include diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index 675ef6c0..6b7579c6 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -8,13 +8,19 @@ #include #include #include -#include #include #include #include +#include +#include #include #include +#include +#include +#include +#include +#include #include namespace StochTree { @@ -22,7 +28,7 @@ namespace StochTree { /*! * \defgroup sampling_group Forest Sampler API * - * \brief Functions for sampling from a forest. The core interface of these functions, + * \brief Functions for sampling from a forest. 
The core interfce of these functions, * as used by the R, Python, and standalone C++ program, is defined by * \ref MCMCSampleOneIter, which runs one iteration of the MCMC sampler for a * given forest, and \ref GFRSampleOneIter, which runs one iteration of the @@ -147,7 +153,7 @@ static inline bool NodeNonConstant(ForestDataset& dataset, ForestTracker& tracke } static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, TreeSplit& split, std::mt19937& gen, Tree* tree, - int tree_num, int leaf_node, int feature_split, bool keep_sorted = false, int num_threads = -1) { + int tree_num, int leaf_node, int feature_split, bool keep_sorted = false) { // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete if (tree->OutputDimension() > 1) { std::vector temp_leaf_values(tree->OutputDimension(), 0.); @@ -160,7 +166,7 @@ static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& datase int right_node = tree->RightChild(leaf_node); // Update the ForestTracker - tracker.AddSplit(dataset.GetCovariates(), split, feature_split, tree_num, leaf_node, left_node, right_node, keep_sorted, num_threads); + tracker.AddSplit(dataset.GetCovariates(), split, feature_split, tree_num, leaf_node, left_node, right_node, keep_sorted); } static inline void RemoveSplitFromModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, std::mt19937& gen, Tree* tree, @@ -295,6 +301,8 @@ static inline void UpdateResidualNewOutcome(ForestTracker& tracker, ColumnVector } } + + static inline void UpdateMeanModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, bool requires_basis, std::function op, bool tree_new) { data_size_t n = dataset.GetCovariates().rows(); @@ -432,7 +440,7 @@ template EvaluateProposedSplit( ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model, TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance, - int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args + LeafSuffStatConstructorArgs&... leaf_suff_stat_args ) { // Initialize sufficient statistics LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...); @@ -440,11 +448,8 @@ static inline std::tuple EvaluatePropo LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...); // Accumulate sufficient statistics - AccumulateSuffStatProposed( - node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, - residual, global_variance, split, tree_num, leaf_num, split_feature, num_threads, - leaf_suff_stat_args... 
- ); + AccumulateSuffStatProposed(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, + residual, global_variance, split, tree_num, leaf_num, split_feature, 1, leaf_suff_stat_args...); data_size_t left_n = left_suff_stat.n; data_size_t right_n = right_suff_stat.n; @@ -481,36 +486,164 @@ static inline std::tuple EvaluateExist template static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { + ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { if constexpr (std::is_same_v) { UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), false); } else if (backfitting) { - UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus(), false); + UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus(), false); } else { - // TODO: think about a generic way to store "state" corresponding to the other models? - UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus(), false); + // TODO: think about a generic way to store "state" corresponding to the other models? + UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus(), false); } } template static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { + ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { if constexpr (std::is_same_v) { UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), true); } else if (backfitting) { - UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus(), true); + UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus(), true); } else { - // TODO: think about a generic way to store "state" corresponding to the other models? - UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus(), true); + // TODO: think about a generic way to store "state" corresponding to the other models? + UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus(), true); } } +template +static inline void EvaluateAllPossibleSplits( + ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, LeafModel& leaf_model, double global_variance, int tree_num, int split_node_id, + std::vector& log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, std::vector& cutpoint_feature_types, + data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, + std::vector& feature_types, std::vector& feature_subset, LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args
+) {
+  // Initialize sufficient statistics
+  LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
+  LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
+  LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
+
+  // Accumulate aggregate sufficient statistic for the node to be split
+  AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, split_node_id);
+
+  // Compute the "no split" log marginal likelihood
+  double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance);
+
+  // Unpack data (bound by reference to avoid copying)
+  Eigen::MatrixXd& covariates = dataset.GetCovariates();
+  Eigen::VectorXd& outcome = residual.GetData();
+  Eigen::VectorXd var_weights;
+  bool has_weights = dataset.HasVarWeights();
+  if (has_weights) var_weights = dataset.GetVarWeights();
+
+  // Minimum size of newly created leaf nodes (used to rule out invalid splits)
+  int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf();
+
+  // Compute sufficient statistics for each possible split
+  data_size_t num_cutpoints = 0;
+  bool valid_split = false;
+  data_size_t current_bin_begin, current_bin_size, next_bin_begin;
+  FeatureType feature_type;
+  double cutoff_value = 0.0;
+  double split_log_ml;
+  for (int j = 0; j < covariates.cols(); j++) {
+    if ((std::abs(variable_weights.at(j)) > kEpsilon) && (feature_subset[j])) {
+      // Enumerate cutpoint strides
+      cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), split_node_id, node_begin, node_end, j, feature_types);
+
+      // Reset sufficient statistics
+      left_suff_stat.ResetSuffStat();
+      right_suff_stat.ResetSuffStat();
+
+      // Iterate through possible cutpoints
+      int32_t num_feature_cutpoints = cutpoint_grid_container.NumCutpoints(j);
+      feature_type = feature_types[j];
+      // Since we partition an entire cutpoint bin to the left, we must stop one bin before the total number of cutpoint bins
+      for (data_size_t cutpoint_idx = 0; cutpoint_idx < (num_feature_cutpoints - 1); cutpoint_idx++) {
+        current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx, j);
+        current_bin_size = cutpoint_grid_container.BinLength(cutpoint_idx, j);
+        next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx + 1, j);
+
+        // Accumulate sufficient statistics for the left node
+        AccumulateCutpointBinSuffStat(left_suff_stat, tracker, cutpoint_grid_container, dataset, residual,
+                                      global_variance, tree_num, split_node_id, j, cutpoint_idx);
+
+        // Compute the corresponding right node sufficient statistics
+        right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat);
+
+        // Store the bin index as the "cutpoint value" - we can use this to query the actual split
+        // value or the set of split categories later on once a split is chosen
+        cutoff_value = cutpoint_idx;
+
+        // Only include cutpoint for consideration if it defines a valid split in the training data
+        valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) &&
+                       right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf));
+        if (valid_split) {
+          num_cutpoints++;
+          // Add to split rule vector
+          cutpoint_feature_types.push_back(feature_type);
+          cutpoint_features.push_back(j);
+          cutpoint_values.push_back(cutoff_value);
+          // Add the log marginal likelihood of the split to the split eval vector
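+          // (split_log_ml - no_split_log_ml is the log Bayes factor of this cutpoint
+          // against not splitting; the tree-prior adjustment and normalization over
+          // all candidates happen in EvaluateCutpoints and the sampling step.)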
+ split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); + log_cutpoint_evaluations.push_back(split_log_ml); + } + } + } + + } + + // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper) + cutpoint_features.push_back(-1); + cutpoint_values.push_back(std::numeric_limits::max()); + cutpoint_feature_types.push_back(FeatureType::kNumeric); + log_cutpoint_evaluations.push_back(no_split_log_ml); + + // Update valid cutpoint count + valid_cutpoint_count = num_cutpoints; +} + +template +static inline void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, + std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, int node_id, data_size_t node_begin, data_size_t node_end, + std::vector& log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, + std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector& variable_weights, + std::vector& feature_types, CutpointGridContainer& cutpoint_grid_container, + std::vector& feature_subset, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + // Evaluate all possible cutpoints according to the leaf node model, + // recording their log-likelihood and other split information in a series of vectors. + // The last element of these vectors concerns the "no-split" option. + EvaluateAllPossibleSplits( + dataset, tracker, residual, tree_prior, gen, leaf_model, global_variance, tree_num, node_id, log_cutpoint_evaluations, + cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, + node_begin, node_end, variable_weights, feature_types, feature_subset, leaf_suff_stat_args... + ); + + // Compute an adjustment to reflect the no split prior probability and the number of cutpoints + double bart_prior_no_split_adj; + double alpha = tree_prior.GetAlpha(); + double beta = tree_prior.GetBeta(); + int node_depth = tree->GetDepth(node_id); + if (valid_cutpoint_count == 0) { + bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0); + } else { + bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count); + } + log_cutpoint_evaluations[log_cutpoint_evaluations.size()-1] += bart_prior_no_split_adj; +} + template static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, std::unordered_map>& node_index_map, std::deque& split_queue, int node_id, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, - std::vector& feature_types, std::vector feature_subset, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + std::vector& feature_types, std::vector feature_subset, LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args) { // Leaf depth int leaf_depth = tree->GetDepth(node_id); @@ -518,153 +651,41 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel int32_t max_depth = tree_prior.GetMaxDepth(); if ((max_depth == -1) || (leaf_depth < max_depth)) { - - // Vector of vectors to store results for each feature - int p = dataset.NumCovariates(); - std::vector> feature_log_cutpoint_evaluations(p+1); - std::vector> feature_cutpoint_values(p+1); - std::vector feature_cutpoint_counts(p+1, 0); + + // Cutpoint enumeration + std::vector log_cutpoint_evaluations; + std::vector cutpoint_features; + std::vector cutpoint_values; + std::vector cutpoint_feature_types; StochTree::data_size_t valid_cutpoint_count; - - // Evaluate all possible cutpoints according to the leaf node model, - // recording their log-likelihood and other split information in a series of vectors. - - // Initialize node sufficient statistics - LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...); - - // Accumulate aggregate sufficient statistic for the node to be split - AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, node_id); - - // Compute the "no split" log marginal likelihood - double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); - - // Unpack data - Eigen::MatrixXd& covariates = dataset.GetCovariates(); - Eigen::VectorXd& outcome = residual.GetData(); - Eigen::VectorXd var_weights; - bool has_weights = dataset.HasVarWeights(); - if (has_weights) var_weights = dataset.GetVarWeights(); - - // Minimum size of newly created leaf nodes (used to rule out invalid splits) - int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf(); - - // Compute sufficient statistics for each possible split - data_size_t num_cutpoints = 0; - if (num_threads == -1) { - num_threads = GetOptimalThreadCount(static_cast(covariates.cols() * covariates.rows())); - } - - // Initialize cutpoint grid container - CutpointGridContainer cutpoint_grid_container(covariates, outcome, cutpoint_grid_size); - - // Evaluate all possible splits for each feature in parallel - StochTree::ParallelFor(0, covariates.cols(), num_threads, [&](int j) { - if ((std::abs(variable_weights.at(j)) > kEpsilon) && (feature_subset[j])) { - // Enumerate cutpoint strides - cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), node_id, node_begin, node_end, j, feature_types); - - // Left and right node sufficient statistics - LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...); - LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...); - - // Iterate through possible cutpoints - int32_t num_feature_cutpoints = cutpoint_grid_container.NumCutpoints(j); - FeatureType feature_type = feature_types[j]; - // Since we partition an entire cutpoint bin to the left, we must stop one bin before the total number of cutpoint bins - for (data_size_t cutpoint_idx = 0; cutpoint_idx < (num_feature_cutpoints - 1); cutpoint_idx++) { - data_size_t current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx, j); - data_size_t current_bin_size = cutpoint_grid_container.BinLength(cutpoint_idx, j); - data_size_t next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx + 1, j); - - // Accumulate sufficient statistics for the left node - AccumulateCutpointBinSuffStat(left_suff_stat, tracker, cutpoint_grid_container, dataset, residual, - global_variance, tree_num, node_id, j, cutpoint_idx); - - // 
Compute the corresponding right node sufficient statistics - right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat); - - // Store the bin index as the "cutpoint value" - we can use this to query the actual split - // value or the set of split categories later on once a split is chose - double cutoff_value = cutpoint_idx; - - // Only include cutpoint for consideration if it defines a valid split in the training data - bool valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) && - right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf)); - if (valid_split) { - feature_cutpoint_counts[j]++; - // Add to split rule vector - feature_cutpoint_values[j].push_back(cutoff_value); - // Add the log marginal likelihood of the split to the split eval vector - double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); - feature_log_cutpoint_evaluations[j].push_back(split_log_ml); - } - } - } - }); - - // Compute total number of cutpoints - valid_cutpoint_count = std::accumulate(feature_cutpoint_counts.begin(), feature_cutpoint_counts.end(), 0); - - // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper) - feature_log_cutpoint_evaluations[covariates.cols()].push_back(no_split_log_ml); - - // Compute an adjustment to reflect the no split prior probability and the number of cutpoints - double bart_prior_no_split_adj; - double alpha = tree_prior.GetAlpha(); - double beta = tree_prior.GetBeta(); - int node_depth = tree->GetDepth(node_id); - if (valid_cutpoint_count == 0) { - bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0); - } else { - bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count); - } - feature_log_cutpoint_evaluations[covariates.cols()][0] += bart_prior_no_split_adj; - + CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); + EvaluateCutpoints( + tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, + cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features, + cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types, + cutpoint_grid_container, feature_subset, leaf_suff_stat_args... + ); + // TODO: maybe add some checks here? 
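+    // Worked illustration (hypothetical numbers): log evaluations (-2.0, -1.0, -1.7)
+    // shift by their maximum to (-1.0, 0.0, -0.7), exponentiate to (0.37, 1.00, 0.50),
+    // and the discrete_distribution below then samples the three options with
+    // probabilities of roughly (0.20, 0.54, 0.27).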
// Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood - double largest_ml = -std::numeric_limits::infinity(); - for (int j = 0; j < p + 1; j++) { - if (feature_log_cutpoint_evaluations[j].size() > 0) { - double feature_max_ml = *std::max_element(feature_log_cutpoint_evaluations[j].begin(), feature_log_cutpoint_evaluations[j].end());; - largest_ml = std::max(largest_ml, feature_max_ml); - } - } - std::vector> feature_cutpoint_evaluations(p+1); - for (int j = 0; j < p + 1; j++) { - if (feature_log_cutpoint_evaluations[j].size() > 0) { - feature_cutpoint_evaluations[j].resize(feature_log_cutpoint_evaluations[j].size()); - for (int i = 0; i < feature_log_cutpoint_evaluations[j].size(); i++) { - feature_cutpoint_evaluations[j][i] = std::exp(feature_log_cutpoint_evaluations[j][i] - largest_ml); - } - } + double largest_mll = *std::max_element(log_cutpoint_evaluations.begin(), log_cutpoint_evaluations.end()); + std::vector cutpoint_evaluations(log_cutpoint_evaluations.size()); + for (data_size_t i = 0; i < log_cutpoint_evaluations.size(); i++){ + cutpoint_evaluations[i] = std::exp(log_cutpoint_evaluations[i] - largest_mll); } - - // Compute sum of marginal likelihoods for each feature - std::vector feature_total_cutpoint_evaluations(p+1, 0.0); - for (int j = 0; j < p + 1; j++) { - if (feature_log_cutpoint_evaluations[j].size() > 0) { - feature_total_cutpoint_evaluations[j] = std::accumulate(feature_cutpoint_evaluations[j].begin(), feature_cutpoint_evaluations[j].end(), 0.0); - } else { - feature_total_cutpoint_evaluations[j] = 0.0; - } - } - - // First, sample a feature according to feature_total_cutpoint_evaluations - std::discrete_distribution feature_dist(feature_total_cutpoint_evaluations.begin(), feature_total_cutpoint_evaluations.end()); - int feature_chosen = feature_dist(gen); - - // Then, sample a cutpoint according to feature_cutpoint_evaluations[feature_chosen] - std::discrete_distribution cutpoint_dist(feature_cutpoint_evaluations[feature_chosen].begin(), feature_cutpoint_evaluations[feature_chosen].end()); - data_size_t cutpoint_chosen = cutpoint_dist(gen); - if (feature_chosen == p){ + // Sample the split (including a "no split" option) + std::discrete_distribution split_dist(cutpoint_evaluations.begin(), cutpoint_evaluations.end()); + data_size_t split_chosen = split_dist(gen); + + if (split_chosen == valid_cutpoint_count){ // "No split" sampled, don't split or add any nodes to split queue return; } else { // Split sampled - int feature_split = feature_chosen; - FeatureType feature_type = feature_types[feature_split]; - double split_value = feature_cutpoint_values[feature_split][cutpoint_chosen]; + int feature_split = cutpoint_features[split_chosen]; + FeatureType feature_type = cutpoint_feature_types[split_chosen]; + double split_value = cutpoint_values[split_chosen]; // Perform all of the relevant "split" operations in the model, tree and training dataset // Compute node sample size @@ -699,7 +720,7 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel } // Add split to tree and trackers - AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true, num_threads); + AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true); // Determine the number of observation in the newly created left node int left_node = tree->LeftChild(node_id); @@ -725,7 +746,7 @@ template & variable_weights, int tree_num, double 
global_variance, std::vector& feature_types, int cutpoint_grid_size, - int num_features_subsample, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + int num_features_subsample, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { int root_id = Tree::kRoot; int curr_node_id; data_size_t curr_node_begin; @@ -781,8 +802,7 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore SampleSplitRule( tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size, node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types, - feature_subset, num_threads, leaf_suff_stat_args... - ); + feature_subset, leaf_suff_stat_args...); } } @@ -820,7 +840,7 @@ template & variable_weights, std::vector& sweep_update_indices, double global_variance, std::vector& feature_types, int cutpoint_grid_size, - bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Run the GFR algorithm for each tree int num_trees = forests.NumTrees(); for (const int& i : sweep_update_indices) { @@ -840,7 +860,7 @@ static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& GFRSampleTreeOneIter( tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, global_variance, feature_types, cutpoint_grid_size, - num_features_subsample, num_threads, leaf_suff_stat_args... + num_features_subsample, leaf_suff_stat_args... ); // Sample leaf parameters for tree i @@ -862,7 +882,7 @@ static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& template static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, int tree_num, std::vector& variable_weights, - double global_variance, double prob_grow_old, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + double global_variance, double prob_grow_old, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Extract dataset information data_size_t n = dataset.GetCovariates().rows(); @@ -907,7 +927,7 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM // Compute the marginal likelihood of split and no split, given the leaf prior std::tuple split_eval = EvaluateProposedSplit( - dataset, tracker, residual, leaf_model, split, tree_num, leaf_chosen, var_chosen, global_variance, num_threads, leaf_suff_stat_args... + dataset, tracker, residual, leaf_model, split, tree_num, leaf_chosen, var_chosen, global_variance, leaf_suff_stat_args... 
); double split_log_marginal_likelihood = std::get<0>(split_eval); double no_split_log_marginal_likelihood = std::get<1>(split_eval); @@ -957,7 +977,7 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM double log_acceptance_prob = std::log(mh_accept(gen)); if (log_acceptance_prob <= log_mh_ratio) { accept = true; - AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false, num_threads); + AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false); } else { accept = false; } @@ -970,7 +990,7 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM template static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, - TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Choose a "leaf parent" node at random int num_leaves = tree->NumLeaves(); int num_leaf_parents = tree->NumLeafParents(); @@ -1049,7 +1069,7 @@ static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, Leaf template static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - int tree_num, double global_variance, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + int tree_num, double global_variance, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Determine whether it is possible to grow any of the leaves bool grow_possible = false; std::vector leaves = tree->GetLeaves(); @@ -1089,11 +1109,11 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For if (step_chosen == 0) { MCMCGrowTreeOneIter( - tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow, num_threads, leaf_suff_stat_args... + tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow, leaf_suff_stat_args... ); } else { MCMCPruneTreeOneIter( - tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, num_threads, leaf_suff_stat_args... + tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, leaf_suff_stat_args... ); } } @@ -1128,8 +1148,7 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For template static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - std::vector& sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, int num_threads, - LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + std::vector& sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args) { // Run the MCMC algorithm for each tree int num_trees = forests.NumTrees(); for (const int& i : sweep_update_indices) { @@ -1144,7 +1163,7 @@ static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tree = active_forest.GetTree(i); MCMCSampleTreeOneIter( tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, - global_variance, num_threads, leaf_suff_stat_args... + global_variance, leaf_suff_stat_args... ); // Sample leaf parameters for tree i diff --git a/include/stochtree/variance_model.h b/include/stochtree/variance_model.h index b1c2dabe..79b8831f 100644 --- a/include/stochtree/variance_model.h +++ b/include/stochtree/variance_model.h @@ -12,7 +12,11 @@ #include #include +#include #include +#include +#include +#include namespace StochTree { diff --git a/man/bart.Rd b/man/bart.Rd index 66a9b9ad..c11c619b 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -136,9 +136,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -153,6 +153,6 @@ X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, +bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) } diff --git a/man/bcf.Rd b/man/bcf.Rd index 01e5fab8..f7d42e93 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -162,21 +162,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -199,8 +199,8 @@ mu_test <- mu_x[test_inds] mu_train <- mu_x[train_inds] tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) } diff --git a/man/cloglog_ordinal_bart.Rd b/man/cloglog_ordinal_bart.Rd new file mode 100644 index 00000000..9c2aed51 --- /dev/null +++ b/man/cloglog_ordinal_bart.Rd @@ -0,0 +1,47 @@ +% Generated 
by roxygen2: do not edit by hand
+% Please edit documentation in R/cloglog_ordinal_bart.R
+\name{cloglog_ordinal_bart}
+\alias{cloglog_ordinal_bart}
+\title{Run the BART algorithm for ordinal outcomes using a complementary log-log link}
+\usage{
+cloglog_ordinal_bart(
+  X,
+  y,
+  X_test = NULL,
+  n_trees = 50,
+  n_samples_mcmc = 500,
+  n_burnin = 250,
+  n_thin = 1,
+  alpha_gamma = 2,
+  beta_gamma = 2,
+  variable_weights = NULL,
+  feature_types = NULL,
+  seed = NULL
+)
+}
+\arguments{
+\item{X}{A numeric matrix of predictors (training data).}
+
+\item{y}{A numeric vector of ordinal outcomes (positive integers starting from 1).}
+
+\item{X_test}{An optional numeric matrix of predictors (test data).}
+
+\item{n_trees}{Number of trees in the BART ensemble. Default: \code{50}.}
+
+\item{n_samples_mcmc}{Total number of MCMC samples to draw. Default: \code{500}.}
+
+\item{n_burnin}{Number of burn-in samples to discard. Default: \code{250}.}
+
+\item{n_thin}{Thinning interval for MCMC samples. Default: \code{1}.}
+
+\item{alpha_gamma}{Shape parameter for the log-gamma prior on cutpoints. Default: \code{2.0}.}
+
+\item{beta_gamma}{Rate parameter for the log-gamma prior on cutpoints. Default: \code{2.0}.}
+
+\item{variable_weights}{Optional vector of variable weights for splitting (default: equal weights).}
+
+\item{feature_types}{Optional vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous).}
+
+\item{seed}{Optional random seed for reproducibility. Default: \code{NULL}.}
+}
+\description{
+Run the BART algorithm for ordinal outcomes using a complementary log-log link
+}
diff --git a/man/createBARTModelFromCombinedJson.Rd b/man/createBARTModelFromCombinedJson.Rd
index 35d185c3..83d61d0d 100644
--- a/man/createBARTModelFromCombinedJson.Rd
+++ b/man/createBARTModelFromCombinedJson.Rd
@@ -22,9 +22,9 @@ n <- 100
 p <- 5
 X <- matrix(runif(n*p), ncol = p)
 f_XW <- (
-    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
-    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
-    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
+    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
+    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
+    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
 )
 noise_sd <- 1
@@ -38,7 +38,7 @@ X_test <- X[test_inds,]
 X_train <- X[train_inds,]
 y_test <- y[test_inds]
 y_train <- y[train_inds]
-bart_model <- bart(X_train = X_train, y_train = y_train, 
+bart_model <- bart(X_train = X_train, y_train = y_train,
                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
 bart_json <- list(saveBARTModelToJson(bart_model))
 bart_model_roundtrip <- createBARTModelFromCombinedJson(bart_json)
diff --git a/man/createBARTModelFromCombinedJsonString.Rd b/man/createBARTModelFromCombinedJsonString.Rd
index a8470dee..7a17484a 100644
--- a/man/createBARTModelFromCombinedJsonString.Rd
+++ b/man/createBARTModelFromCombinedJsonString.Rd
@@ -22,9 +22,9 @@ n <- 100
 p <- 5
 X <- matrix(runif(n*p), ncol = p)
 f_XW <- (
-    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
-    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
-    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
+    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
+    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
+    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
 )
 noise_sd <- 1
@@ -38,7 +38,7 @@ X_test <- X[test_inds,]
 X_train <- X[train_inds,]
 y_test <- y[test_inds]
 y_train <- y[train_inds]
-bart_model <- bart(X_train = X_train, y_train = y_train, 
+bart_model <- bart(X_train = X_train, y_train = y_train,
                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
 bart_json_string_list <- 
list(saveBARTModelToJsonString(bart_model)) bart_model_roundtrip <- createBARTModelFromCombinedJsonString(bart_json_string_list) diff --git a/man/createBARTModelFromJson.Rd b/man/createBARTModelFromJson.Rd index 57686122..68a02f0e 100644 --- a/man/createBARTModelFromJson.Rd +++ b/man/createBARTModelFromJson.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJson(bart_model) bart_model_roundtrip <- createBARTModelFromJson(bart_json) diff --git a/man/createBARTModelFromJsonFile.Rd b/man/createBARTModelFromJsonFile.Rd index f714a94a..7608d8d2 100644 --- a/man/createBARTModelFromJsonFile.Rd +++ b/man/createBARTModelFromJsonFile.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) tmpjson <- tempfile(fileext = ".json") saveBARTModelToJsonFile(bart_model, file.path(tmpjson)) diff --git a/man/createBARTModelFromJsonString.Rd b/man/createBARTModelFromJsonString.Rd index 67068fd0..0748d97a 100644 --- a/man/createBARTModelFromJsonString.Rd +++ b/man/createBARTModelFromJsonString.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJsonString(bart_model) bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) diff --git a/man/createBCFModelFromCombinedJson.Rd b/man/createBCFModelFromCombinedJson.Rd index 6f29569e..24c82e4f 100644 --- a/man/createBCFModelFromCombinedJson.Rd +++ b/man/createBCFModelFromCombinedJson.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > 
X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json_list <- list(saveBCFModelToJson(bcf_model)) bcf_model_roundtrip <- createBCFModelFromCombinedJson(bcf_json_list) diff --git a/man/createBCFModelFromCombinedJsonString.Rd b/man/createBCFModelFromCombinedJsonString.Rd index bd7e63f2..e0522f75 100644 --- a/man/createBCFModelFromCombinedJsonString.Rd +++ b/man/createBCFModelFromCombinedJsonString.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = 
X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model)) bcf_model_roundtrip <- createBCFModelFromCombinedJsonString(bcf_json_string_list) diff --git a/man/createBCFModelFromJson.Rd b/man/createBCFModelFromJson.Rd index a579b140..35cff7ce 100644 --- a/man/createBCFModelFromJson.Rd +++ b/man/createBCFModelFromJson.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) bcf_json <- saveBCFModelToJson(bcf_model) bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) diff --git a/man/createBCFModelFromJsonFile.Rd b/man/createBCFModelFromJsonFile.Rd index 2661d4de..a2496797 100644 --- a/man/createBCFModelFromJsonFile.Rd +++ b/man/createBCFModelFromJsonFile.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= 
X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) tmpjson <- tempfile(fileext = ".json") saveBCFModelToJsonFile(bcf_model, file.path(tmpjson)) diff --git a/man/createBCFModelFromJsonString.Rd b/man/createBCFModelFromJsonString.Rd index 5f34724c..cc944f85 100644 --- a/man/createBCFModelFromJsonString.Rd +++ b/man/createBCFModelFromJsonString.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = 
rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json <- saveBCFModelToJsonString(bcf_model) bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json) diff --git a/man/createForestModel.Rd b/man/createForestModel.Rd index d9000925..d7a1adae 100644 --- a/man/createForestModel.Rd +++ b/man/createForestModel.Rd @@ -30,10 +30,10 @@ max_depth <- 10 feature_types <- as.integer(rep(0, p)) X <- matrix(runif(n*p), ncol = p) forest_dataset <- createForestDataset(X) -forest_model_config <- createForestModelConfig(feature_types=feature_types, - num_trees=num_trees, num_features=p, - num_observations=n, alpha=alpha, beta=beta, - min_samples_leaf=min_samples_leaf, +forest_model_config <- createForestModelConfig(feature_types=feature_types, + num_trees=num_trees, num_features=p, + num_observations=n, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, max_depth=max_depth, leaf_model_type=1) global_model_config <- createGlobalModelConfig(global_error_variance=1.0) forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) diff --git a/man/getRandomEffectSamples.bartmodel.Rd b/man/getRandomEffectSamples.bartmodel.Rd index 0da1eb98..149586a8 100644 --- a/man/getRandomEffectSamples.bartmodel.Rd +++ b/man/getRandomEffectSamples.bartmodel.Rd @@ -24,9 +24,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) snr <- 3 @@ -51,11 +51,11 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, +bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) rfx_samples <- getRandomEffectSamples(bart_model) } diff --git a/man/getRandomEffectSamples.bcfmodel.Rd b/man/getRandomEffectSamples.bcfmodel.Rd index 6769de62..08a8eae4 100644 --- a/man/getRandomEffectSamples.bcfmodel.Rd +++ b/man/getRandomEffectSamples.bcfmodel.Rd @@ -24,21 +24,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) 
* (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -74,15 +74,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) rfx_samples <- getRandomEffectSamples(bcf_model) } diff --git a/man/predict.bartmodel.Rd b/man/predict.bartmodel.Rd index 2afccbf6..8a0a47bf 100644 --- a/man/predict.bartmodel.Rd +++ b/man/predict.bartmodel.Rd @@ -40,9 +40,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -56,7 +56,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) y_hat_test <- predict(bart_model, X_test)$y_hat } diff --git a/man/predict.bcfmodel.Rd b/man/predict.bcfmodel.Rd index ff315808..907e5308 100644 --- a/man/predict.bcfmodel.Rd +++ b/man/predict.bcfmodel.Rd @@ -42,21 +42,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > 
X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -79,8 +79,8 @@ mu_test <- mu_x[test_inds] mu_train <- mu_x[train_inds] tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, num_gfr = 10, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) preds <- predict(bcf_model, X_test, Z_test, pi_test) } diff --git a/man/preprocessPredictionData.Rd b/man/preprocessPredictionData.Rd index f881fda8..a6382e69 100644 --- a/man/preprocessPredictionData.Rd +++ b/man/preprocessPredictionData.Rd @@ -22,7 +22,7 @@ types. Matrices will be passed through assuming all columns are numeric. } \examples{ cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) -metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, +metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) X_preprocessed <- preprocessPredictionData(cov_df, metadata) } diff --git a/man/resetForestModel.Rd b/man/resetForestModel.Rd index f0fec6ca..b02158d4 100644 --- a/man/resetForestModel.Rd +++ b/man/resetForestModel.Rd @@ -48,23 +48,23 @@ y <- -5 + 10*(X[,1] > 0.5) + rnorm(n) outcome <- createOutcome(y) rng <- createCppRNG(1234) global_model_config <- createGlobalModelConfig(global_error_variance=sigma2) -forest_model_config <- createForestModelConfig(feature_types=feature_types, - num_trees=num_trees, num_observations=n, - num_features=p, alpha=alpha, beta=beta, - min_samples_leaf=min_samples_leaf, - max_depth=max_depth, - variable_weights=variable_weights, - cutpoint_grid_size=cutpoint_grid_size, - leaf_model_type=leaf_model, +forest_model_config <- createForestModelConfig(feature_types=feature_types, + num_trees=num_trees, num_observations=n, + num_features=p, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, + max_depth=max_depth, + variable_weights=variable_weights, + cutpoint_grid_size=cutpoint_grid_size, + leaf_model_type=leaf_model, leaf_model_scale=leaf_scale) forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) -forest_samples <- createForestSamples(num_trees, leaf_dimension, +forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) 
forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, - rng, forest_model_config, global_model_config, + forest_dataset, outcome, forest_samples, active_forest, + rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = FALSE ) resetActiveForest(active_forest, forest_samples, 0) diff --git a/man/resetRandomEffectsModel.Rd b/man/resetRandomEffectsModel.Rd index fec99b77..b032ccc2 100644 --- a/man/resetRandomEffectsModel.Rd +++ b/man/resetRandomEffectsModel.Rd @@ -49,8 +49,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) diff --git a/man/resetRandomEffectsTracker.Rd b/man/resetRandomEffectsTracker.Rd index 5249ca96..c57af16a 100644 --- a/man/resetRandomEffectsTracker.Rd +++ b/man/resetRandomEffectsTracker.Rd @@ -57,8 +57,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) diff --git a/man/rootResetRandomEffectsModel.Rd b/man/rootResetRandomEffectsModel.Rd index c58a09e9..4c3cc2f7 100644 --- a/man/rootResetRandomEffectsModel.Rd +++ b/man/rootResetRandomEffectsModel.Rd @@ -63,8 +63,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, diff --git a/man/rootResetRandomEffectsTracker.Rd b/man/rootResetRandomEffectsTracker.Rd index 8de2c514..6f2dc843 100644 --- a/man/rootResetRandomEffectsTracker.Rd +++ b/man/rootResetRandomEffectsTracker.Rd @@ -49,8 +49,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, diff --git a/man/saveBARTModelToJson.Rd b/man/saveBARTModelToJson.Rd index a617532e..054af24e 100644 --- a/man/saveBARTModelToJson.Rd +++ b/man/saveBARTModelToJson.Rd @@ -20,9 +20,9 @@ n <- 100 p <- 5 X 
<- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -36,7 +36,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJson(bart_model) } diff --git a/man/saveBARTModelToJsonFile.Rd b/man/saveBARTModelToJsonFile.Rd index 46a3110e..62ef6ad7 100644 --- a/man/saveBARTModelToJsonFile.Rd +++ b/man/saveBARTModelToJsonFile.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) tmpjson <- tempfile(fileext = ".json") saveBARTModelToJsonFile(bart_model, file.path(tmpjson)) diff --git a/man/saveBARTModelToJsonString.Rd b/man/saveBARTModelToJsonString.Rd index c83f9e5d..10927c20 100644 --- a/man/saveBARTModelToJsonString.Rd +++ b/man/saveBARTModelToJsonString.Rd @@ -20,9 +20,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -36,7 +36,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json_string <- saveBARTModelToJsonString(bart_model) } diff --git a/man/saveBCFModelToJson.Rd b/man/saveBCFModelToJson.Rd index ae2c286d..2c04d76c 100644 --- a/man/saveBCFModelToJson.Rd +++ b/man/saveBCFModelToJson.Rd @@ -20,21 +20,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= 
X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,15 +70,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) bcf_json <- saveBCFModelToJson(bcf_model) } diff --git a/man/saveBCFModelToJsonFile.Rd b/man/saveBCFModelToJsonFile.Rd index e6a9f0aa..584bbbba 100644 --- a/man/saveBCFModelToJsonFile.Rd +++ b/man/saveBCFModelToJsonFile.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + 
prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) tmpjson <- tempfile(fileext = ".json") saveBCFModelToJsonFile(bcf_model, file.path(tmpjson)) diff --git a/man/saveBCFModelToJsonString.Rd b/man/saveBCFModelToJsonString.Rd index 4328e525..2182bbe3 100644 --- a/man/saveBCFModelToJsonString.Rd +++ b/man/saveBCFModelToJsonString.Rd @@ -20,21 +20,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,15 +70,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) saveBCFModelToJsonString(bcf_model) } diff --git a/src/Makevars.in b/src/Makevars.in index 4eb970cb..850e2555 100644 --- a/src/Makevars.in +++ b/src/Makevars.in @@ -34,6 +34,7 @@ OBJECTS = \ data.o \ io.o \ leaf_model.o \ + ordinal_sampler.o \ partition_tracker.o \ random_effects.o \ tree.o diff --git a/src/Makevars.win.in b/src/Makevars.win.in index 95bff1dd..e9d54ab6 100644 --- a/src/Makevars.win.in +++ b/src/Makevars.win.in @@ -34,6 +34,7 @@ OBJECTS = \ data.o \ io.o \ leaf_model.o \ + ordinal_sampler.o \ partition_tracker.o \ random_effects.o \ tree.o diff --git a/src/R_data.cpp b/src/R_data.cpp index 39b77ab3..1396575f 100644 --- a/src/R_data.cpp +++ b/src/R_data.cpp @@ -5,6 +5,7 @@ #include #include #include +#include [[cpp11::register]] cpp11::external_pointer create_forest_dataset_cpp() { diff --git a/src/R_random_effects.cpp b/src/R_random_effects.cpp index e291121c..f627b3c5 100644 --- a/src/R_random_effects.cpp +++ b/src/R_random_effects.cpp @@ -7,7 +7,9 @@ #include #include #include +#include #include +#include [[cpp11::register]] cpp11::external_pointer 
rfx_container_cpp(int num_components, int num_groups) { diff --git a/src/cpp11.cpp b/src/cpp11.cpp index ef98aac0..881c5314 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -1157,18 +1157,18 @@ extern "C" SEXP _stochtree_compute_leaf_indices_cpp(SEXP forest_container, SEXP END_CPP11 } // sampler.cpp -void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, int num_features_subsample, int num_threads); -extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP sweep_indices, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest, SEXP num_features_subsample, SEXP num_threads) { +void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, int num_features_subsample); +extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP sweep_indices, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest, SEXP num_features_subsample) { BEGIN_CPP11 - sample_gfr_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(sweep_indices), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(num_features_subsample), cpp11::as_cpp>(num_threads)); + sample_gfr_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(sweep_indices), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(num_features_subsample)); return R_NilValue; END_CPP11 } // sampler.cpp -void 
sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, int num_threads); -extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP sweep_indices, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest, SEXP num_threads) { +void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest); +extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP sweep_indices, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest) { BEGIN_CPP11 - sample_mcmc_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(sweep_indices), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(num_threads)); + sample_mcmc_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(sweep_indices), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest)); return R_NilValue; END_CPP11 } @@ -1281,6 +1281,83 @@ extern "C" SEXP _stochtree_sample_without_replacement_integer_cpp(SEXP populatio return cpp11::as_sexp(sample_without_replacement_integer_cpp(cpp11::as_cpp>(population_vector), cpp11::as_cpp>(sampling_probs), cpp11::as_cpp>(sample_size))); END_CPP11 } +// sampler.cpp +void ordinal_aux_data_initialize_cpp(cpp11::external_pointer tracker_ptr, StochTree::data_size_t num_observations, int n_levels); +extern "C" SEXP _stochtree_ordinal_aux_data_initialize_cpp(SEXP tracker_ptr, SEXP num_observations, 
SEXP n_levels) { + BEGIN_CPP11 + ordinal_aux_data_initialize_cpp(cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(num_observations), cpp11::as_cpp>(n_levels)); + return R_NilValue; + END_CPP11 +} +// sampler.cpp +double ordinal_aux_data_get_cpp(cpp11::external_pointer tracker_ptr, int type_idx, StochTree::data_size_t obs_idx); +extern "C" SEXP _stochtree_ordinal_aux_data_get_cpp(SEXP tracker_ptr, SEXP type_idx, SEXP obs_idx) { + BEGIN_CPP11 + return cpp11::as_sexp(ordinal_aux_data_get_cpp(cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(type_idx), cpp11::as_cpp>(obs_idx))); + END_CPP11 +} +// sampler.cpp +void ordinal_aux_data_set_cpp(cpp11::external_pointer tracker_ptr, int type_idx, StochTree::data_size_t obs_idx, double value); +extern "C" SEXP _stochtree_ordinal_aux_data_set_cpp(SEXP tracker_ptr, SEXP type_idx, SEXP obs_idx, SEXP value) { + BEGIN_CPP11 + ordinal_aux_data_set_cpp(cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(type_idx), cpp11::as_cpp>(obs_idx), cpp11::as_cpp>(value)); + return R_NilValue; + END_CPP11 +} +// sampler.cpp +cpp11::writable::doubles ordinal_aux_data_get_vector_cpp(cpp11::external_pointer tracker_ptr, int type_idx); +extern "C" SEXP _stochtree_ordinal_aux_data_get_vector_cpp(SEXP tracker_ptr, SEXP type_idx) { + BEGIN_CPP11 + return cpp11::as_sexp(ordinal_aux_data_get_vector_cpp(cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(type_idx))); + END_CPP11 +} +// sampler.cpp +void ordinal_aux_data_set_vector_cpp(cpp11::external_pointer tracker_ptr, int type_idx, cpp11::doubles values); +extern "C" SEXP _stochtree_ordinal_aux_data_set_vector_cpp(SEXP tracker_ptr, SEXP type_idx, SEXP values) { + BEGIN_CPP11 + ordinal_aux_data_set_vector_cpp(cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(type_idx), cpp11::as_cpp>(values)); + return R_NilValue; + END_CPP11 +} +// sampler.cpp +void ordinal_aux_data_update_cumsum_exp_cpp(cpp11::external_pointer tracker_ptr); +extern "C" SEXP _stochtree_ordinal_aux_data_update_cumsum_exp_cpp(SEXP tracker_ptr) { + BEGIN_CPP11 + ordinal_aux_data_update_cumsum_exp_cpp(cpp11::as_cpp>>(tracker_ptr)); + return R_NilValue; + END_CPP11 +} +// sampler.cpp +cpp11::external_pointer ordinal_sampler_cpp(); +extern "C" SEXP _stochtree_ordinal_sampler_cpp() { + BEGIN_CPP11 + return cpp11::as_sexp(ordinal_sampler_cpp()); + END_CPP11 +} +// sampler.cpp +void ordinal_sampler_update_latent_variables_cpp(cpp11::external_pointer sampler_ptr, cpp11::external_pointer data_ptr, cpp11::external_pointer outcome_ptr, cpp11::external_pointer tracker_ptr, cpp11::external_pointer rng_ptr); +extern "C" SEXP _stochtree_ordinal_sampler_update_latent_variables_cpp(SEXP sampler_ptr, SEXP data_ptr, SEXP outcome_ptr, SEXP tracker_ptr, SEXP rng_ptr) { + BEGIN_CPP11 + ordinal_sampler_update_latent_variables_cpp(cpp11::as_cpp>>(sampler_ptr), cpp11::as_cpp>>(data_ptr), cpp11::as_cpp>>(outcome_ptr), cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>>(rng_ptr)); + return R_NilValue; + END_CPP11 +} +// sampler.cpp +void ordinal_sampler_update_gamma_params_cpp(cpp11::external_pointer sampler_ptr, cpp11::external_pointer data_ptr, cpp11::external_pointer outcome_ptr, cpp11::external_pointer tracker_ptr, double alpha_gamma, double beta_gamma, double gamma_0, cpp11::external_pointer rng_ptr); +extern "C" SEXP _stochtree_ordinal_sampler_update_gamma_params_cpp(SEXP sampler_ptr, SEXP data_ptr, SEXP outcome_ptr, SEXP tracker_ptr, SEXP alpha_gamma, SEXP beta_gamma, SEXP gamma_0, SEXP rng_ptr) { + BEGIN_CPP11 + ordinal_sampler_update_gamma_params_cpp(cpp11::as_cpp>>(sampler_ptr), cpp11::as_cpp>>(data_ptr), 
cpp11::as_cpp>>(outcome_ptr), cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(alpha_gamma), cpp11::as_cpp>(beta_gamma), cpp11::as_cpp>(gamma_0), cpp11::as_cpp>>(rng_ptr)); + return R_NilValue; + END_CPP11 +} +// sampler.cpp +void ordinal_sampler_update_cumsum_exp_cpp(cpp11::external_pointer sampler_ptr, cpp11::external_pointer tracker_ptr); +extern "C" SEXP _stochtree_ordinal_sampler_update_cumsum_exp_cpp(SEXP sampler_ptr, SEXP tracker_ptr) { + BEGIN_CPP11 + ordinal_sampler_update_cumsum_exp_cpp(cpp11::as_cpp>>(sampler_ptr), cpp11::as_cpp>>(tracker_ptr)); + return R_NilValue; + END_CPP11 +} // serialization.cpp cpp11::external_pointer init_json_cpp(); extern "C" SEXP _stochtree_init_json_cpp() { @@ -1711,6 +1788,16 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_num_split_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_split_nodes_forest_container_cpp, 3}, {"_stochtree_num_trees_active_forest_cpp", (DL_FUNC) &_stochtree_num_trees_active_forest_cpp, 1}, {"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1}, + {"_stochtree_ordinal_aux_data_get_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_get_cpp, 3}, + {"_stochtree_ordinal_aux_data_get_vector_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_get_vector_cpp, 2}, + {"_stochtree_ordinal_aux_data_initialize_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_initialize_cpp, 3}, + {"_stochtree_ordinal_aux_data_set_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_set_cpp, 4}, + {"_stochtree_ordinal_aux_data_set_vector_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_set_vector_cpp, 3}, + {"_stochtree_ordinal_aux_data_update_cumsum_exp_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_update_cumsum_exp_cpp, 1}, + {"_stochtree_ordinal_sampler_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_cpp, 0}, + {"_stochtree_ordinal_sampler_update_cumsum_exp_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_cumsum_exp_cpp, 2}, + {"_stochtree_ordinal_sampler_update_gamma_params_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_gamma_params_cpp, 8}, + {"_stochtree_ordinal_sampler_update_latent_variables_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_latent_variables_cpp, 5}, {"_stochtree_overwrite_column_vector_cpp", (DL_FUNC) &_stochtree_overwrite_column_vector_cpp, 2}, {"_stochtree_parent_node_forest_container_cpp", (DL_FUNC) &_stochtree_parent_node_forest_container_cpp, 4}, {"_stochtree_predict_active_forest_cpp", (DL_FUNC) &_stochtree_predict_active_forest_cpp, 2}, @@ -1776,8 +1863,8 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, {"_stochtree_root_reset_active_forest_cpp", (DL_FUNC) &_stochtree_root_reset_active_forest_cpp, 1}, {"_stochtree_root_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_root_reset_rfx_tracker_cpp, 4}, - {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 19}, - {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 18}, + {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 18}, + {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 17}, {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 5}, {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 4}, {"_stochtree_sample_without_replacement_integer_cpp", (DL_FUNC) &_stochtree_sample_without_replacement_integer_cpp, 3}, diff 
--git a/src/cutpoint_candidates.cpp b/src/cutpoint_candidates.cpp index e43b8219..4a0845c7 100644 --- a/src/cutpoint_candidates.cpp +++ b/src/cutpoint_candidates.cpp @@ -2,6 +2,7 @@ #include #include +#include namespace StochTree { diff --git a/src/data.cpp b/src/data.cpp index e48e9255..cd2913cf 100644 --- a/src/data.cpp +++ b/src/data.cpp @@ -1,6 +1,7 @@ /*! Copyright (c) 2024 by stochtree authors */ #include #include +#include namespace StochTree { diff --git a/src/forest.cpp b/src/forest.cpp index 968fe95c..02757aa7 100644 --- a/src/forest.cpp +++ b/src/forest.cpp @@ -7,7 +7,9 @@ #include #include #include +#include #include +#include [[cpp11::register]] cpp11::external_pointer active_forest_cpp(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false) { diff --git a/src/io.cpp b/src/io.cpp index 50774d9b..1324957f 100644 --- a/src/io.cpp +++ b/src/io.cpp @@ -7,7 +7,9 @@ #include #include +#include #include +#include namespace StochTree { diff --git a/src/kernel.cpp b/src/kernel.cpp index 88f12c53..6b5867bb 100644 --- a/src/kernel.cpp +++ b/src/kernel.cpp @@ -3,6 +3,8 @@ #include #include #include +#include +#include typedef Eigen::Map> DoubleMatrixType; typedef Eigen::Map> IntMatrixType; diff --git a/src/leaf_model.cpp b/src/leaf_model.cpp index 78d8da76..3f39fba5 100644 --- a/src/leaf_model.cpp +++ b/src/leaf_model.cpp @@ -1,5 +1,7 @@ #include #include +#include +#include namespace StochTree { diff --git a/src/ordinal_sampler.cpp b/src/ordinal_sampler.cpp index 27b11f63..19a7c6b5 100644 --- a/src/ordinal_sampler.cpp +++ b/src/ordinal_sampler.cpp @@ -16,10 +16,10 @@ void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::Vector const std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // gamma cutpoints const std::vector& lambda_hat = tracker.GetOrdinalAuxDataVector(1); // forest predictions: lambda_hat_i = sum_t lambda_t(x_i) std::vector& Z = tracker.GetOrdinalAuxDataVector(0); // latent variables: z_i ~ TExp(e^{gamma[y_i] + lambda_hat_i}; 0, 1) - + int K = gamma.size() + 1; // Number of ordinal categories - int N = dataset.NumObservations(); - + int N = dataset.NumObservations(); + // Update truncated exponentials (stored in latent auxiliary data slot 0) // z_i ~ TExp(rate = e^{gamma[y_i] + lambda_hat_i}; 0, 1) // where y_i is the ordinal outcome for observation i: make sure y_i converted to {0, 1, ..., K-1} @@ -27,7 +27,7 @@ void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::Vector // If y_i = K-1 (last category), then we set z_i = 1.0 deterministically just for bookkeeping, we don't need it // We only need to sample latent z_i for y_i < K-1 (as z_i is only used in the likelihood for y_i < K-1) for (int i = 0; i < N; i++) { - int y = static_cast(outcome(i)); + int y = static_cast(outcome(i)); if (y == K - 1) { Z[i] = 1.0; } else { @@ -44,14 +44,14 @@ void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // cutpoints gamma_k's const std::vector& Z = tracker.GetOrdinalAuxDataVector(0); // latent variables z_i's const std::vector& lambda_hat = tracker.GetOrdinalAuxDataVector(1); // forest predictions: lambda_hat_i = sum_t lambda_t(x_i) - + int K = gamma.size() + 1; // Number of ordinal categories int N = dataset.NumObservations(); // Compute sufficient statistics A[k] and B[k] for gamma[k] update std::vector A(K - 1, 0.0); std::vector B(K - 1, 0.0); - + for (int i = 0; i < N; i++) { int y = static_cast(outcome(i)); 
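// Aside: the UpdateLatentVariables hunk above draws z_i ~ TExp(rate = e^{gamma[y_i] + lambda_hat_i}; 0, 1),
// an exponential truncated to (0, 1). A minimal self-contained sketch of that draw via the inverse CDF
// (illustrative names, not the patch's API): for U ~ Uniform(0, 1),
//   z = -log(1 - U * (1 - e^{-rate})) / rate
// has CDF (1 - e^{-rate*z}) / (1 - e^{-rate}) on (0, 1), i.e. the truncated exponential.
#include <cmath>
#include <random>

inline double sample_truncated_exponential(double rate, std::mt19937& gen) {
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  double u = unif(gen);
  // Invert the truncated-exponential CDF; rate > 0 is assumed (it is exp(...) in the model)
  return -std::log(1.0 - u * (1.0 - std::exp(-rate))) / rate;
}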
if (y < K - 1) { @@ -62,16 +62,16 @@ void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& B[k] += std::exp(lambda_hat[i]); } } - - // Update gamma parameters using log-gamma sampling + + // Update gamma parameters using log-gamma sampling // First sample all gamma parameters - for (int k = 0; k < static_cast(gamma.size()); k++) { + for (int k = 0; k < static_cast(gamma.size()); k++) { double shape = A[k] + alpha_gamma; - double rate = B[k] + beta_gamma; + double rate = B[k] + beta_gamma; double gamma_sample = gamma_sampler_.Sample(shape, rate, gen); gamma[k] = std::log(gamma_sample); } - + // Set the first gamma parameter to gamma_0 (e.g., 0) for identifiability gamma[0] = gamma_0; } @@ -80,7 +80,7 @@ void OrdinalSampler::UpdateCumulativeExpSums(ForestTracker& tracker) { // Get auxiliary data vectors const std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // cutpoints gamma_k's std::vector& seg = tracker.GetOrdinalAuxDataVector(3); // seg_k = sum_{j=0}^{k-1} exp(gamma_j) - + // Update seg (sum of exponentials of gamma cutpoints) for (int j = 0; j < static_cast(seg.size()); j++) { if (j == 0) { diff --git a/src/partition_tracker.cpp b/src/partition_tracker.cpp index 8359faed..bb35efd7 100644 --- a/src/partition_tracker.cpp +++ b/src/partition_tracker.cpp @@ -6,6 +6,12 @@ #include #include +#include +#include +#include +#include +#include + namespace StochTree { ForestTracker::ForestTracker(Eigen::MatrixXd& covariates, std::vector& feature_types, int num_trees, int num_observations) { @@ -28,15 +34,15 @@ void ForestTracker::ReconstituteFromForest(TreeEnsemble& forest, ForestDataset& // (1) Updates the residual by adding currently cached tree predictions and subtracting predictions from new tree // (2) Updates sample_node_mapper_, sample_pred_mapper_, and sum_predictions_ based on the new forest UpdateSampleTrackersResidual(forest, dataset, residual, is_mean_model); - + // Since GFR always starts over from root, this data structure can always simply be reset Eigen::MatrixXd& covariates = dataset.GetCovariates(); sorted_node_sample_tracker_.reset(new SortedNodeSampleTracker(presort_container_.get(), covariates, feature_types_)); - + // Reconstitute each of the remaining data structures in the tracker based on splits in the ensemble // UnsortedNodeSampleTracker unsorted_node_sample_tracker_->ReconstituteFromForest(forest, dataset); - + } void ForestTracker::ResetRoot(Eigen::MatrixXd& covariates, std::vector& feature_types, int32_t tree_num) { @@ -156,7 +162,7 @@ void ForestTracker::UpdateSampleTrackersResidualInternalBasis(TreeEnsemble& fore for (int j = 0; j < num_trees_; j++) { // Query the previously cached prediction for tree j, observation i prev_tree_pred = sample_pred_mapper_->GetPred(i, j); - + // Compute the new prediction for tree j, observation i new_tree_pred = 0.0; Tree* tree = forest.GetTree(j); @@ -164,7 +170,7 @@ void ForestTracker::UpdateSampleTrackersResidualInternalBasis(TreeEnsemble& fore for (int32_t k = 0; k < output_dim; k++) { new_tree_pred += tree->LeafValue(nidx, k) * basis(i, k); } - + if (is_mean_model) { // Adjust the residual by adding the previous prediction and subtracting the new prediction new_resid = residual.GetElement(i) - new_tree_pred + prev_tree_pred; @@ -202,7 +208,7 @@ void ForestTracker::UpdateSampleTrackersResidualInternalNoBasis(TreeEnsemble& fo Tree* tree = forest.GetTree(j); std::int32_t nidx = EvaluateTree(*tree, covariates, i); new_tree_pred = tree->LeafValue(nidx, 0); - + if (is_mean_model) { // Adjust the 
residual by adding the previous prediction and subtracting the new prediction new_resid = residual.GetElement(i) - new_tree_pred + prev_tree_pred; @@ -211,7 +217,7 @@ void ForestTracker::UpdateSampleTrackersResidualInternalNoBasis(TreeEnsemble& fo new_weight = std::log(dataset.VarWeightValue(i)) + new_tree_pred - prev_tree_pred; dataset.SetVarWeightValue(i, new_weight, true); } - + // Update the sample node mapper and sample prediction mapper sample_node_mapper_->SetNodeId(i, j, nidx); sample_pred_mapper_->SetPred(i, j, new_tree_pred); @@ -280,7 +286,7 @@ void ForestTracker::AddSplit(Eigen::MatrixXd& covariates, TreeSplit& split, int3 sample_node_mapper_->AddSplit(covariates, split, split_feature, tree_id, split_node_id, left_node_id, right_node_id); unsorted_node_sample_tracker_->PartitionTreeNode(covariates, tree_id, split_node_id, left_node_id, right_node_id, split_feature, split); if (keep_sorted) { - sorted_node_sample_tracker_->PartitionNode(covariates, split_node_id, split_feature, split, num_threads); + sorted_node_sample_tracker_->PartitionNode(covariates, split_node_id, split_feature, split); } } @@ -346,21 +352,21 @@ void FeatureUnsortedPartition::ReconstituteFromTree(Tree& tree, ForestDataset& d CHECK_EQ(num_deleted_nodes_, 0); data_size_t n = dataset.NumObservations(); CHECK_EQ(indices_.size(), n); - + // Extract covariates Eigen::MatrixXd& covariates = dataset.GetCovariates(); // Set node counters num_nodes_ = tree.NumNodes(); num_deleted_nodes_ = tree.NumDeletedNodes(); - + // Resize tracking vectors node_begin_.resize(num_nodes_); node_length_.resize(num_nodes_); parent_nodes_.resize(num_nodes_); left_nodes_.resize(num_nodes_); right_nodes_.resize(num_nodes_); - + // Unpack tree's splits into this data structure bool is_deleted; TreeNodeType node_type; @@ -399,11 +405,11 @@ void FeatureUnsortedPartition::ReconstituteFromTree(Tree& tree, ForestDataset& d } else { continue; } - // Partition the node indices + // Partition the node indices auto node_begin = (indices_.begin() + node_begin_[i]); auto node_end = (indices_.begin() + node_begin_[i] + node_length_[i]); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return split_rule.SplitTrue(covariates(row, split_index)); }); - + // Determine the number of true and false elements node_begin = (indices_.begin() + node_begin_[i]); num_true = std::distance(node_begin, right_node_begin); @@ -415,7 +421,7 @@ void FeatureUnsortedPartition::ReconstituteFromTree(Tree& tree, ForestDataset& d parent_nodes_[left_nodes_[i]] = i; left_nodes_[left_nodes_[i]] = StochTree::Tree::kInvalidNodeId; left_nodes_[right_nodes_[i]] = StochTree::Tree::kInvalidNodeId; - + // Add right node tracking information node_begin_[right_nodes_[i]] = node_start_idx + num_true; node_length_[right_nodes_[i]] = num_false; @@ -455,11 +461,11 @@ void FeatureUnsortedPartition::PartitionNode(Eigen::MatrixXd& covariates, int no data_size_t node_start_idx = node_begin_[node_id]; data_size_t num_node_elements = node_length_[node_id]; - // Partition the node indices + // Partition the node indices auto node_begin = (indices_.begin() + node_begin_[node_id]); auto node_end = (indices_.begin() + node_begin_[node_id] + node_length_[node_id]); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return split.SplitTrue(covariates(row, feature_split)); }); - + // Determine the number of true and false elements node_begin = (indices_.begin() + node_begin_[node_id]); data_size_t num_true = std::distance(node_begin, 
right_node_begin); @@ -474,11 +480,11 @@ void FeatureUnsortedPartition::PartitionNode(Eigen::MatrixXd& covariates, int no data_size_t node_start_idx = node_begin_[node_id]; data_size_t num_node_elements = node_length_[node_id]; - // Partition the node indices + // Partition the node indices auto node_begin = (indices_.begin() + node_begin_[node_id]); auto node_end = (indices_.begin() + node_begin_[node_id] + node_length_[node_id]); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return RowSplitLeft(covariates, row, feature_split, split_value); }); - + // Determine the number of true and false elements node_begin = (indices_.begin() + node_begin_[node_id]); data_size_t num_true = std::distance(node_begin, right_node_begin); @@ -493,11 +499,11 @@ void FeatureUnsortedPartition::PartitionNode(Eigen::MatrixXd& covariates, int no data_size_t node_start_idx = node_begin_[node_id]; data_size_t num_node_elements = node_length_[node_id]; - // Partition the node indices + // Partition the node indices auto node_begin = (indices_.begin() + node_begin_[node_id]); auto node_end = (indices_.begin() + node_begin_[node_id] + node_length_[node_id]); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return RowSplitLeft(covariates, row, feature_split, category_list); }); - + // Determine the number of true and false elements node_begin = (indices_.begin() + node_begin_[node_id]); data_size_t num_true = std::distance(node_begin, right_node_begin); @@ -536,7 +542,7 @@ void FeatureUnsortedPartition::ExpandNodeTrackingVectors(int node_id, int left_n parent_nodes_[left_node_id] = node_id; left_nodes_[left_node_id] = StochTree::Tree::kInvalidNodeId; left_nodes_[right_node_id] = StochTree::Tree::kInvalidNodeId; - + // Add right node tracking information right_nodes_[node_id] = right_node_id; node_begin_[right_node_id] = node_start_idx + num_left; @@ -578,7 +584,7 @@ bool FeatureUnsortedPartition::RightNodeIsLeaf(int node_id) { } void FeatureUnsortedPartition::PruneNodeToLeaf(int node_id) { - // No need to "un-sift" the indices in the newly pruned node, we don't depend on the indices + // No need to "un-sift" the indices in the newly pruned node, we don't depend on the indices // having any type of sort order, so the indices will simply be "re-sifted" if the node is later partitioned if (IsLeaf(node_id)) return; if (!LeftNodeIsLeaf(node_id)) { @@ -614,7 +620,7 @@ std::vector FeatureUnsortedPartition::NodeIndices(int node_id) { void FeaturePresortPartition::AddLeftRightNodes(data_size_t left_node_begin, data_size_t left_node_size, data_size_t right_node_begin, data_size_t right_node_size) { // Assumes that we aren't pruning / deleting nodes, since this is for use with recursive algorithms - + // Add the left ("true") node to the offset size vector node_offset_sizes_.emplace_back(left_node_begin, left_node_size); // Add the right ("false") node to the offset size vector @@ -627,11 +633,11 @@ void FeaturePresortPartition::SplitFeature(Eigen::MatrixXd& covariates, int32_t data_size_t node_end_idx = NodeEnd(node_id); data_size_t num_node_elements = NodeSize(node_id); - // Partition the node indices + // Partition the node indices auto node_begin = (feature_sort_indices_.begin() + node_start_idx); auto node_end = (feature_sort_indices_.begin() + node_end_idx); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return split.SplitTrue(covariates(row, feature_index)); }); - + // Add the left and right nodes to the offset size 
vector node_begin = (feature_sort_indices_.begin() + node_start_idx); data_size_t num_true = std::distance(node_begin, right_node_begin); @@ -645,11 +651,11 @@ void FeaturePresortPartition::SplitFeatureNumeric(Eigen::MatrixXd& covariates, i data_size_t node_end_idx = NodeEnd(node_id); data_size_t num_node_elements = NodeSize(node_id); - // Partition the node indices + // Partition the node indices auto node_begin = (feature_sort_indices_.begin() + node_start_idx); auto node_end = (feature_sort_indices_.begin() + node_end_idx); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return RowSplitLeft(covariates, row, feature_index, split_value); }); - + // Add the left and right nodes to the offset size vector node_begin = (feature_sort_indices_.begin() + node_start_idx); data_size_t num_true = std::distance(node_begin, right_node_begin); @@ -663,11 +669,11 @@ void FeaturePresortPartition::SplitFeatureCategorical(Eigen::MatrixXd& covariate data_size_t node_end_idx = NodeEnd(node_id); data_size_t num_node_elements = NodeSize(node_id); - // Partition the node indices + // Partition the node indices auto node_begin = (feature_sort_indices_.begin() + node_start_idx); auto node_end = (feature_sort_indices_.begin() + node_end_idx); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return RowSplitLeft(covariates, row, feature_index, category_list); }); - + // Add the left and right nodes to the offset size vector node_begin = (feature_sort_indices_.begin() + node_start_idx); data_size_t num_true = std::distance(node_begin, right_node_begin); @@ -696,12 +702,12 @@ std::vector FeaturePresortPartition::NodeIndices(int node_id) { return out; } - // ============================================================================ // ORDINAL AUXILIARY DATA METHODS // ============================================================================ double ForestTracker::GetOrdinalAuxData(int type_idx, data_size_t obs_idx) const { + // CHECK(IsValidOrdinalIndex(type_idx, obs_idx)); return ordinal_aux_data_vec_[type_idx][obs_idx]; } @@ -710,10 +716,12 @@ void ForestTracker::InitializeOrdinalAuxData(data_size_t num_observations, int n } void ForestTracker::SetOrdinalAuxData(int type_idx, data_size_t obs_idx, double value) { + // CHECK(IsValidOrdinalIndex(type_idx, obs_idx)); ordinal_aux_data_vec_[type_idx][obs_idx] = value; } std::vector& ForestTracker::GetOrdinalAuxDataVector(int type_idx) { + // CHECK(IsValidOrdinalType(type_idx)); return ordinal_aux_data_vec_[type_idx]; } @@ -735,4 +743,16 @@ void ForestTracker::ResizeOrdinalAuxData(data_size_t num_observations, int n_lev } } +// bool ForestTracker::IsValidOrdinalType(int type_idx) const { +// return (type_idx >= 0 && type_idx < static_cast(ordinal_aux_data_vec_.size()) && +// !ordinal_aux_data_vec_.empty()); +// } + +// bool ForestTracker::IsValidOrdinalIndex(int type_idx, data_size_t obs_idx) const { +// if (!IsValidOrdinalType(type_idx)) { +// return false; +// } +// return (obs_idx >= 0 && obs_idx < ordinal_aux_data_vec_[type_idx].size()); +// } + } // namespace StochTree diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 950caeb8..34931fa9 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #define STRINGIFY(x) #x #define MACRO_STRINGIFY(x) STRINGIFY(x) @@ -1077,7 +1078,7 @@ class ForestSamplerCpp { void SampleOneIteration(ForestContainerCpp& forest_samples, ForestCpp& forest, ForestDatasetCpp& dataset, ResidualCpp& 
residual, RngCpp& rng, py::array_t feature_types, py::array_t sweep_update_indices, int cutpoint_grid_size, py::array_t leaf_model_scale_input, py::array_t variable_weights, double a_forest, double b_forest, double global_variance, - int leaf_model_int, int num_features_subsample, bool keep_forest = true, bool gfr = true, int num_threads = -1) { + int leaf_model_int, int num_features_subsample, bool keep_forest = true, bool gfr = true) { // Refactoring completely out of the Python interface. // Intention to refactor out of the C++ interface in the future. bool pre_initialized = true; @@ -1139,23 +1140,23 @@ class ForestSamplerCpp { std::mt19937* rng_ptr = rng.GetRng(); if (gfr) { if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads, num_basis); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), 
*forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample); } } else { if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_threads, num_basis); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, false); } } } diff --git a/src/sampler.cpp b/src/sampler.cpp index ee8bd6e6..255f6e7c 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -4,39 +4,41 @@ #include #include #include +#include #include #include #include -#include +#include #include +#include +#include [[cpp11::register]] -void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, - cpp11::external_pointer residual, - cpp11::external_pointer forest_samples, - cpp11::external_pointer active_forest, - cpp11::external_pointer tracker, - cpp11::external_pointer split_prior, - cpp11::external_pointer rng, - cpp11::integers sweep_indices, - cpp11::integers feature_types, int cutpoint_grid_size, - cpp11::doubles_matrix<> 
leaf_model_scale_input, - cpp11::doubles variable_weights, +void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, + cpp11::external_pointer residual, + cpp11::external_pointer forest_samples, + cpp11::external_pointer active_forest, + cpp11::external_pointer tracker, + cpp11::external_pointer split_prior, + cpp11::external_pointer rng, + cpp11::integers sweep_indices, + cpp11::integers feature_types, int cutpoint_grid_size, + cpp11::doubles_matrix<> leaf_model_scale_input, + cpp11::doubles variable_weights, double a_forest, double b_forest, - double global_variance, int leaf_model_int, - bool keep_forest, int num_features_subsample, - int num_threads + double global_variance, int leaf_model_int, + bool keep_forest, int num_features_subsample ) { // Refactoring completely out of the R interface. // Intention to refactor out of the C++ interface in the future. bool pre_initialized = true; - + // Unpack feature types std::vector feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { feature_types_[i] = static_cast(feature_types[i]); } - + // Unpack sweep indices std::vector sweep_indices_(sweep_indices.size()); // if (sweep_indices.size() > 0) { @@ -45,19 +47,20 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer var_weights_vector(variable_weights.size()); for (int i = 0; i < variable_weights.size(); i++) { var_weights_vector[i] = variable_weights[i]; } - + // Prepare the samplers StochTree::LeafModelVariant leaf_model = StochTree::leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest); int num_basis = data->NumBasis(); - + // Run one iteration of the sampler if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads, num_basis); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, 
feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample); } } [[cpp11::register]] -void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, - cpp11::external_pointer residual, - cpp11::external_pointer forest_samples, - cpp11::external_pointer active_forest, - cpp11::external_pointer tracker, - cpp11::external_pointer split_prior, - cpp11::external_pointer rng, - cpp11::integers sweep_indices, - cpp11::integers feature_types, int cutpoint_grid_size, - cpp11::doubles_matrix<> leaf_model_scale_input, - cpp11::doubles variable_weights, +void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, + cpp11::external_pointer residual, + cpp11::external_pointer forest_samples, + cpp11::external_pointer active_forest, + cpp11::external_pointer tracker, + cpp11::external_pointer split_prior, + cpp11::external_pointer rng, + cpp11::integers sweep_indices, + cpp11::integers feature_types, int cutpoint_grid_size, + cpp11::doubles_matrix<> leaf_model_scale_input, + cpp11::doubles variable_weights, double a_forest, double b_forest, - double global_variance, int leaf_model_int, - bool keep_forest, int num_threads + double global_variance, int leaf_model_int, + bool keep_forest ) { // Refactoring completely out of the R interface. // Intention to refactor out of the C++ interface in the future. 
bool pre_initialized = true; - + // Unpack feature types std::vector feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { feature_types_[i] = static_cast(feature_types[i]); } - + // Unpack sweep indices std::vector sweep_indices_; if (sweep_indices.size() > 0) { @@ -127,19 +130,20 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer var_weights_vector(variable_weights.size()); for (int i = 0; i < variable_weights.size(); i++) { var_weights_vector[i] = variable_weights[i]; } - + // Prepare the samplers StochTree::LeafModelVariant leaf_model = StochTree::leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest); int num_basis = data->NumBasis(); - + // Run one iteration of the sampler if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads, num_basis); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false); + } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false); } } [[cpp11::register]] -double sample_sigma2_one_iteration_cpp(cpp11::external_pointer residual, - cpp11::external_pointer dataset, - cpp11::external_pointer rng, +double sample_sigma2_one_iteration_cpp(cpp11::external_pointer residual, + cpp11::external_pointer dataset, + cpp11::external_pointer rng, double a, double b ) { // 
Run one iteration of the sampler @@ -191,8 +197,8 @@ double sample_sigma2_one_iteration_cpp(cpp11::external_pointer active_forest, - cpp11::external_pointer rng, +double sample_tau_one_iteration_cpp(cpp11::external_pointer active_forest, + cpp11::external_pointer rng, double a, double b ) { // Run one iteration of the sampler @@ -209,7 +215,7 @@ cpp11::external_pointer rng_cpp(int random_seed = -1) { } else { rng_ = std::make_unique(random_seed); } - + // Release management of the pointer to R session return cpp11::external_pointer(rng_.release()); } @@ -218,7 +224,7 @@ cpp11::external_pointer rng_cpp(int random_seed = -1) { cpp11::external_pointer tree_prior_cpp(double alpha, double beta, int min_samples_leaf, int max_depth = -1) { // Create smart pointer to newly allocated object std::unique_ptr prior_ptr_ = std::make_unique(alpha, beta, min_samples_leaf, max_depth); - + // Release management of the pointer to R session return cpp11::external_pointer(prior_ptr_.release()); } @@ -275,10 +281,10 @@ cpp11::external_pointer forest_tracker_cpp(cpp11::exte for (int i = 0; i < feature_types.size(); i++) { feature_types_[i] = static_cast(feature_types[i]); } - + // Create smart pointer to newly allocated object std::unique_ptr tracker_ptr_ = std::make_unique(data->GetCovariates(), feature_types_, num_trees, n); - + // Release management of the pointer to R session return cpp11::external_pointer(tracker_ptr_.release()); } @@ -295,8 +301,8 @@ cpp11::writable::doubles get_cached_forest_predictions_cpp(cpp11::external_point [[cpp11::register]] cpp11::writable::integers sample_without_replacement_integer_cpp( - cpp11::integers population_vector, - cpp11::doubles sampling_probs, + cpp11::integers population_vector, + cpp11::doubles sampling_probs, int sample_size ) { // Unpack pointer to population vector @@ -308,14 +314,14 @@ cpp11::writable::integers sample_without_replacement_integer_cpp( // Create output vector cpp11::writable::integers output(sample_size); - + // Unpack pointer to output vector int* output_ptr = INTEGER(PROTECT(output)); // Create C++ RNG std::random_device rd; std::mt19937 gen(rd()); - + // Run the sampler StochTree::sample_without_replacement( output_ptr, sampling_probs_ptr, population_vector_ptr, population_size, sample_size, gen @@ -372,8 +378,8 @@ void ordinal_aux_data_set_vector_cpp(cpp11::external_pointer tracker_ptr) { // Get auxiliary data vectors const std::vector& gamma = tracker_ptr->GetOrdinalAuxDataVector(2); // cutpoints gamma - std::vector& seg = tracker_ptr->GetOrdinalAuxDataVector(3); // cumsum exp gamma - + std::vector& seg = tracker_ptr->GetOrdinalAuxDataVector(3); // cumsum exp gamma + // Update seg (cumulative sum of exp(gamma)) for (size_t j = 0; j < seg.size(); j++) { if (j == 0) { @@ -397,7 +403,7 @@ cpp11::external_pointer ordinal_sampler_cpp() { [[cpp11::register]] void ordinal_sampler_update_latent_variables_cpp( cpp11::external_pointer sampler_ptr, - cpp11::external_pointer data_ptr, + cpp11::external_pointer data_ptr, cpp11::external_pointer outcome_ptr, cpp11::external_pointer tracker_ptr, cpp11::external_pointer rng_ptr @@ -427,3 +433,4 @@ void ordinal_sampler_update_cumsum_exp_cpp( sampler_ptr->UpdateCumulativeExpSums(*tracker_ptr); } + diff --git a/src/serialization.cpp b/src/serialization.cpp index fb248f62..749395e8 100644 --- a/src/serialization.cpp +++ b/src/serialization.cpp @@ -8,6 +8,9 @@ #include #include #include +#include +#include +#include [[cpp11::register]] cpp11::external_pointer init_json_cpp() { diff --git a/src/stochtree_types.h 
b/src/stochtree_types.h
index d3d6327c..9f4e77df 100644
--- a/src/stochtree_types.h
+++ b/src/stochtree_types.h
@@ -1,8 +1,10 @@
 #include
 #include
 #include
+#include
 #include
 #include
+#include
 #include
 #include
 #include
diff --git a/src/tree.cpp b/src/tree.cpp
index 32c51475..fa6fd8f8 100644
--- a/src/tree.cpp
+++ b/src/tree.cpp
@@ -8,6 +8,9 @@
 #include
 #include
+#include
+#include
+#include
 namespace StochTree {
@@ -665,6 +668,7 @@ void Tree::from_json(const json& tree_json) {
   tree_json.at("has_categorical_split").get_to(this->has_categorical_split_);
   tree_json.at("output_dimension").get_to(this->output_dimension_);
   tree_json.at("is_log_scale").get_to(this->is_log_scale_);
+  this->num_deleted_nodes = 0;
   // Unpack the array based fields
   JsonToTreeNodeVectors(tree_json, this);

From c8492fb7650f7eccbb4e79fcc7b5797c5dbece80 Mon Sep 17 00:00:00 2001
From: Entejar Alam
Date: Sun, 28 Sep 2025 06:24:48 -0500
Subject: [PATCH 7/9] Tested CLogLog Ordinal BART — running successfully!
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tools/debug/testing_cloglog_ordinal_bart.R | 111 +++++++++++++++++++++
 1 file changed, 111 insertions(+)
 create mode 100644 tools/debug/testing_cloglog_ordinal_bart.R

diff --git a/tools/debug/testing_cloglog_ordinal_bart.R b/tools/debug/testing_cloglog_ordinal_bart.R
new file mode 100644
index 00000000..78f8d61c
--- /dev/null
+++ b/tools/debug/testing_cloglog_ordinal_bart.R
@@ -0,0 +1,111 @@
+# Simulate ordinal data and run Cloglog Ordinal BART

+# Load
+library(stochtree)
+
+set.seed(2025)
+
+# Simulation
+n_samples <- 2000
+n_features <- 5
+n_categories <- 3
+
+X <- matrix(rnorm(n_samples * n_features), n_samples, n_features)
+
+beta <- rep(1 / sqrt(n_features), n_features)
+gamma_true <- c(-2, 1)
+
+linear_predictor <- X %*% beta
+
+# Transform linear predictor using the complementary log-log link function
+p_0 <- 1 - exp(-exp(gamma_true[1] + linear_predictor))
+p_1 <- exp(-exp(gamma_true[1] + linear_predictor)) *
+  (1 - exp(-exp(gamma_true[2] + linear_predictor)))
+p_2 <- exp(-exp(gamma_true[1] + linear_predictor)) *
+  exp(-exp(gamma_true[2] + linear_predictor))
+
+true_probs <- cbind(p_0, p_1, p_2)
+
+# Get Outcomes
+ordinal_outcome <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ]))
+cat("Outcome distribution:", table(ordinal_outcome), "\n")
+
+train_index <- 1:(n_samples/2)
+test_index <- (1:n_samples)[- train_index]
+
+X_train <- X[train_index, ]
+y_train <- ordinal_outcome[train_index]
+X_test <- X[-train_index, ]
+y_test <- ordinal_outcome[-train_index]
+
+out <- cloglog_ordinal_bart(
+  X = X_train,
+  y = y_train,
+  X_test = X_test,
+  n_samples_mcmc = 1000,
+  n_burnin = 500,
+  n_thin = 1
+)
+
+
+# Inference and diagnostics
+par(mfrow = c(2, 1))
+plot(out$gamma_samples[1, ], type = 'l', main = expression(gamma[1]), ylab = "Value", xlab = "MCMC Sample")
+abline(h = gamma_true[1], col = 'red', lty = 2)
+plot(out$gamma_samples[2, ], type = 'l', main = expression(gamma[2]), ylab = "Value", xlab = "MCMC Sample")
+abline(h = gamma_true[2], col = 'red', lty = 2)
+
+gamma1 <- out$gamma_samples[1,] + colMeans(out$forest_predictions_train)
+summary(gamma1)
+hist(gamma1)
+
+gamma2 <- out$gamma_samples[2,] + colMeans(out$forest_predictions_train)
+summary(gamma2)
+hist(gamma2)
+
+par(mfrow = c(3,2), mar = c(5,4,1,1))
+rowMeans(out$gamma_samples)
+moo <- t(out$gamma_samples) + colMeans(out$forest_predictions_train)
+plot(moo[,1])
+abline(h = gamma_true[1]
+ mean(linear_predictor[train_index])) +plot(moo[,2]) +abline(h = gamma_true[2] + mean(linear_predictor[train_index])) +plot(out$gamma_samples[1,]) +plot(out$gamma_samples[2,]) + +# Compare forest predictions with the truth + +plot(rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train), linear_predictor[train_index]) +abline(a=0,b=1,col='blue', lwd=2) + +plot(rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test), linear_predictor[test_index]) +abline(a=0,b=1,col='blue', lwd=2) + +# Train set ordinal class probabilities + +p_hat_0 <- rowMeans(1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[1, ]))) +p_hat_1 <- rowMeans((1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[2,]))) * exp(-exp(out$forest_predictions_train + out$gamma_samples[1,]))) +p_hat_2 <- 1 - p_hat_1 - p_hat_0 + +mean(log(-log(1 - p_hat_0)) - rowMeans(out$forest_predictions_train)) + +plot(p_hat_0, p_0[train_index]) +abline(a=0,b=1,col='blue', lwd=2) +plot(p_hat_1, p_1[train_index]) +abline(a=0,b=1,col='blue', lwd=2) +plot(p_hat_2, p_2[train_index]) +abline(a=0,b=1,col='blue', lwd=2) + +# Test set ordinal class probabilities + +p_hat_0 <- rowMeans(1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[1, ]))) +p_hat_1 <- rowMeans((1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[2,]))) * exp(-exp(out$forest_predictions_test + out$gamma_samples[1,]))) +p_hat_2 <- 1 - p_hat_1 - p_hat_0 + +plot(p_hat_0, p_0[test_index]) +abline(a=0,b=1,col='blue', lwd=2) +plot(p_hat_1, p_1[test_index]) +abline(a=0,b=1,col='blue', lwd=2) +plot(p_hat_2, p_2[test_index]) +abline(a=0,b=1,col='blue', lwd=2) + From 444c0674e8d8d538e044ecddcc7deea4570c3d55 Mon Sep 17 00:00:00 2001 From: Entejar Alam Date: Sun, 28 Sep 2025 19:37:54 -0500 Subject: [PATCH 8/9] Added vignette for CLogLog Ordinal Bart --- NAMESPACE | 1 + R/cloglog_ordinal_bart.R | 1 + tools/debug/testing_cloglog_ordinal_bart.R | 159 +++++++++++-------- vignettes/CLogLogOrdinalBart.Rmd | 173 +++++++++++++++++++++ vignettes/vignettes.bib | 9 +- 5 files changed, 278 insertions(+), 65 deletions(-) create mode 100644 vignettes/CLogLogOrdinalBart.Rmd diff --git a/NAMESPACE b/NAMESPACE index 2f4103c0..a4062f5e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -7,6 +7,7 @@ S3method(predict,bcfmodel) export(bart) export(bcf) export(calibrateInverseGammaErrorVariance) +export(cloglog_ordinal_bart) export(computeForestLeafIndices) export(computeForestLeafVariances) export(computeForestMaxLeafIndex) diff --git a/R/cloglog_ordinal_bart.R b/R/cloglog_ordinal_bart.R index 9cc9b63a..a8117c77 100644 --- a/R/cloglog_ordinal_bart.R +++ b/R/cloglog_ordinal_bart.R @@ -11,6 +11,7 @@ #' @param beta_gamma Rate parameter for the log-gamma prior on cutpoints. Default: `2.0`. #' @param variable_weights Optional vector of variable weights for splitting (default: equal weights). #' @param feature_types Optional vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous). 
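+#' @return A list of posterior outputs, including `gamma_samples` (draws of the
+#'   cutpoint parameters, one row per cutpoint) and `forest_predictions_train` /
+#'   `forest_predictions_test` (draws of the forest predictions, one row per
+#'   observation and one column per retained MCMC sample).
+#' @examples
+#' \dontrun{
+#' # Illustrative sketch only (mirrors tools/debug/testing_cloglog_ordinal_bart.R);
+#' # arguments not shown are assumed to keep their defaults.
+#' X <- matrix(rnorm(200 * 5), 200, 5)
+#' y <- sample(1:3, 200, replace = TRUE)
+#' fit <- cloglog_ordinal_bart(X = X, y = y, n_samples_mcmc = 100, n_burnin = 50, n_thin = 1)
+#' dim(fit$gamma_samples)
+#' }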
+#' @export cloglog_ordinal_bart <- function(X, y, X_test = NULL, diff --git a/tools/debug/testing_cloglog_ordinal_bart.R b/tools/debug/testing_cloglog_ordinal_bart.R index 78f8d61c..71ef790a 100644 --- a/tools/debug/testing_cloglog_ordinal_bart.R +++ b/tools/debug/testing_cloglog_ordinal_bart.R @@ -5,38 +5,47 @@ library(stochtree) set.seed(2025) -# Simulation -n_samples <- 2000 -n_features <- 5 -n_categories <- 3 - -X <- matrix(rnorm(n_samples * n_features), n_samples, n_features) - -beta <- rep(1 / sqrt(n_features), n_features) -gamma_true <- c(-2, 1) - -linear_predictor <- X %*% beta - -# Transform linear predictor using the complementary log-log link function -p_0 <- 1 - exp(-exp(gamma_true[1] + linear_predictor)) -p_1 <- exp(-exp(gamma_true[1] + linear_predictor)) * - (1 - exp(-exp(gamma_true[2] + linear_predictor))) -p_2 <- exp(-exp(gamma_true[1] + linear_predictor)) * - exp(-exp(gamma_true[2] + linear_predictor)) +# Sample size and number of predictors +n <- 2000 +p <- 5 -true_probs <- cbind(p_0, p_1, p_2) +# Design matrix and true lambda function +X <- matrix(rnorm(n * p), n, p) +beta <- rep(1 / sqrt(p), p) +true_lambda_function <- X %*% beta -# Get Outcomes -ordinal_outcome <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ])) -cat("Outcome distribution:", table(ordinal_outcome), "\n") -train_index <- 1:(n_samples/2) -test_index <- (1:n_samples)[- train_index] - -X_train <- X[train_index, ] -y_train <- ordinal_outcome[train_index] -X_test <- X[-train_index, ] -y_test <- ordinal_outcome[-train_index] +# Set cutpoints for ordinal categories (3 categories: 1, 2, 3) +n_categories <- 3 +gamma_true <- c(-2, 1) +ordinal_cutpoints <- log(cumsum(exp(gamma_true))) +ordinal_cutpoints + +# True ordinal class probabilities +true_probs <- matrix(0, nrow = n, ncol = n_categories) +for (j in 1:n_categories) { + if (j == 1) { + true_probs[, j] <- 1 - exp(-exp(gamma_true[j] + true_lambda_function)) + } else if (j == n_categories) { + true_probs[, j] <- 1 - rowSums(true_probs[, 1:(j - 1), drop = FALSE]) + } else { + true_probs[, j] <- exp(-exp(gamma_true[j - 1] + true_lambda_function)) * + (1 - exp(-exp(gamma_true[j] + true_lambda_function))) + } +} + +# Generate ordinal outcomes +y <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ])) +cat("Outcome distribution:", table(y), "\n") + +# CLogLog Ordinal BART model fitting +train_idx <- sample(1:n, size = floor(0.8 * n)) +test_idx <- setdiff(1:n, train_idx) + +X_train <- X[train_idx, ] +y_train <- y[train_idx] +X_test <- X[test_idx, ] +y_test <- y[test_idx] out <- cloglog_ordinal_bart( X = X_train, @@ -67,45 +76,67 @@ par(mfrow = c(3,2), mar = c(5,4,1,1)) rowMeans(out$gamma_samples) moo <- t(out$gamma_samples) + colMeans(out$forest_predictions_train) plot(moo[,1]) -abline(h = gamma_true[1] + mean(linear_predictor[train_index])) +abline(h = gamma_true[1] + mean(true_lambda_function[train_idx])) plot(moo[,2]) -abline(h = gamma_true[2] + mean(linear_predictor[train_index])) +abline(h = gamma_true[2] + mean(true_lambda_function[train_idx])) plot(out$gamma_samples[1,]) plot(out$gamma_samples[2,]) -# Compare forest predictions with the truth - -plot(rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train), linear_predictor[train_index]) -abline(a=0,b=1,col='blue', lwd=2) - -plot(rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test), linear_predictor[test_index]) +# Compare forest predictions with the truth function (for training and test sets) 
+lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train) +plot(lambda_pred_train, true_lambda_function[train_idx]) abline(a=0,b=1,col='blue', lwd=2) +cor_train <- cor(true_lambda_function[train_idx], lambda_pred_train) +text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red') -# Train set ordinal class probabilities - -p_hat_0 <- rowMeans(1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[1, ]))) -p_hat_1 <- rowMeans((1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[2,]))) * exp(-exp(out$forest_predictions_train + out$gamma_samples[1,]))) -p_hat_2 <- 1 - p_hat_1 - p_hat_0 - -mean(log(-log(1 - p_hat_0)) - rowMeans(out$forest_predictions_train)) - -plot(p_hat_0, p_0[train_index]) -abline(a=0,b=1,col='blue', lwd=2) -plot(p_hat_1, p_1[train_index]) -abline(a=0,b=1,col='blue', lwd=2) -plot(p_hat_2, p_2[train_index]) -abline(a=0,b=1,col='blue', lwd=2) - -# Test set ordinal class probabilities - -p_hat_0 <- rowMeans(1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[1, ]))) -p_hat_1 <- rowMeans((1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[2,]))) * exp(-exp(out$forest_predictions_test + out$gamma_samples[1,]))) -p_hat_2 <- 1 - p_hat_1 - p_hat_0 - -plot(p_hat_0, p_0[test_index]) +lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test) +plot(lambda_pred_test, true_lambda_function[test_idx]) abline(a=0,b=1,col='blue', lwd=2) -plot(p_hat_1, p_1[test_index]) -abline(a=0,b=1,col='blue', lwd=2) -plot(p_hat_2, p_2[test_index]) -abline(a=0,b=1,col='blue', lwd=2) - +cor_test <- cor(true_lambda_function[test_idx], lambda_pred_test) +text(min(true_lambda_function[test_idx]), max(true_lambda_function[test_idx]), paste('Correlation:', round(cor_test, 3)), adj = 0, col = 'red') + +# Estimated ordinal class probabilities for the training set +est_probs_train <- matrix(0, nrow=length(train_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_train[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j, ]))) + } else if (j == n_categories) { + est_probs_train[, j] <- 1 - rowSums(est_probs_train[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_train[, j] <- rowMeans(exp(-exp(out$forest_predictions_train + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j,])))) + } +} + +mean(log(-log(1 - est_probs_train[, 1])) - rowMeans(out$forest_predictions_train)) + +# Compare estimated vs true class probabilities for training set +for (j in 1:n_categories) { + plot(true_probs[train_idx, j], est_probs_train[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_train_prob <- cor(true_probs[train_idx, j], est_probs_train[, j]) + text(min(true_probs[train_idx, j]), max(est_probs_train[, j]), paste('Correlation:', round(cor_train_prob, 3)), adj = 0, col = 'red') +} + +# Estimated ordinal class probabilities for the test set +est_probs_test <- matrix(0, nrow=length(test_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_test[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j, ]))) + } else if (j == n_categories) { + est_probs_test[, j] <- 1 - rowSums(est_probs_test[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_test[, j] <- rowMeans(exp(-exp(out$forest_predictions_test 
+ out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j,])))) + } +} + +mean(log(-log(1 - est_probs_test[, 1])) - rowMeans(out$forest_predictions_test)) + +# Compare estimated vs true class probabilities for test set +for (j in 1:n_categories) { + plot(true_probs[test_idx, j], est_probs_test[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_test_prob <- cor(true_probs[test_idx, j], est_probs_test[, j]) + text(min(true_probs[test_idx, j]), max(est_probs_test[, j]), paste('Correlation:', round(cor_test_prob, 3)), adj = 0, col = 'red') +} diff --git a/vignettes/CLogLogOrdinalBart.Rmd b/vignettes/CLogLogOrdinalBart.Rmd new file mode 100644 index 00000000..a87b1ebb --- /dev/null +++ b/vignettes/CLogLogOrdinalBart.Rmd @@ -0,0 +1,173 @@ +--- +title: "Complementary Log-Log Ordinal BART in StochTree" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{CLogLog-Ordinal-BART} + %\VignetteEncoding{UTF-8} + %\VignetteEngine{knitr::rmarkdown} +bibliography: vignettes.bib +editor_options: + markdown: + wrap: 72 +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) +``` + +This vignette demonstrates how to use the `cloglog_ordinal_bart()` function for modeling ordinal outcomes using a complementary log-log link function in the BART (Bayesian Additive Regression Trees) framework. + +To begin, we load the `stochtree` package. + +```{r setup} +library(stochtree) +``` + +# Introduction to Ordinal BART with CLogLog Link + +Ordinal data represents outcomes that have a natural ordering but undefined distances between categories. Examples include survey responses (strongly disagree, disagree, neutral, agree, strongly agree), severity ratings (mild, moderate, severe), or educational levels (elementary, high school, college, graduate). + +The complementary log-log (CLogLog) model uses the link function: +$$\text{cloglog}(p) = \log(-\log(1-p))$$ + +This link function is asymmetric and particularly appropriate when the probability of being in higher categories changes rapidly at certain thresholds, making it different from the symmetric probit or logit links commonly used in ordinal regression. + +In the BART framework with CLogLog ordinal regression, we model: +$$P(Y = k \mid Y \geq k, X = x) = 1 - \exp\left(-e^{\gamma_k + \lambda(x)}\right)$$ + +where $\lambda(x)$ is learned by the BART ensemble and $c_k = \log \sum_{j \leq k}e^{\gamma_j}$ are the cutpoints for the ordinal categories. 
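+
+These conditional (discrete-hazard) probabilities determine the marginal class probabilities
+
+$$P(Y = k \mid X = x) = \exp\left(-\sum_{j < k} e^{\gamma_j + \lambda(x)}\right)\left(1 - \exp\left(-e^{\gamma_k + \lambda(x)}\right)\right),$$
+
+with the final category absorbing the remaining probability mass. The sketch below is purely illustrative (the `lambda` and `gamma` values are hypothetical stand-ins rather than sampler output) and shows how the per-category hazards chain together for an arbitrary number of categories:
+
+```{r cloglog-link-demo}
+# Illustrative helper: marginal class probabilities implied by the CLogLog hazards.
+# `lambda` is a vector of regression function values; `gamma` holds the K - 1
+# cutpoint parameters (hypothetical values here, not sampler output).
+cloglog_class_probs <- function(lambda, gamma) {
+  K <- length(gamma) + 1
+  # P(Y = k | Y >= k, x) for each observation and non-terminal category
+  hazards <- sapply(gamma, function(g) 1 - exp(-exp(g + lambda)))
+  probs <- matrix(0, nrow = length(lambda), ncol = K)
+  surv <- rep(1, length(lambda))  # P(Y >= k), equal to 1 at k = 1
+  for (k in 1:(K - 1)) {
+    probs[, k] <- surv * hazards[, k]
+    surv <- surv * (1 - hazards[, k])
+  }
+  probs[, K] <- surv  # the last category absorbs the remaining mass
+  probs
+}
+
+# Rows sum to one (gamma here matches the simulation below)
+rowSums(cloglog_class_probs(lambda = c(-1, 0, 1), gamma = c(-2, 1)))
+```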
+
+## Data Simulation
+
+```{r demo1-simulation}
+set.seed(2025)
+# Sample size and number of predictors
+n <- 2000
+p <- 5
+
+# Design matrix and true lambda function
+X <- matrix(rnorm(n * p), n, p)
+beta <- rep(1 / sqrt(p), p)
+true_lambda_function <- X %*% beta
+
+
+# Set cutpoints for ordinal categories (3 categories: 1, 2, 3)
+n_categories <- 3
+gamma_true <- c(-2, 1)
+ordinal_cutpoints <- log(cumsum(exp(gamma_true)))
+ordinal_cutpoints
+
+# True ordinal class probabilities
+true_probs <- matrix(0, nrow = n, ncol = n_categories)
+for (j in 1:n_categories) {
+  if (j == 1) {
+    true_probs[, j] <- 1 - exp(-exp(gamma_true[j] + true_lambda_function))
+  } else if (j == n_categories) {
+    true_probs[, j] <- 1 - rowSums(true_probs[, 1:(j - 1), drop = FALSE])
+  } else {
+    # Survival past the previous category times the hazard at category j; with
+    # n_categories = 3 there is a single middle category, so one survival factor
+    # suffices (for more categories this term would accumulate all earlier hazards)
+    true_probs[, j] <- exp(-exp(gamma_true[j - 1] + true_lambda_function)) *
+      (1 - exp(-exp(gamma_true[j] + true_lambda_function)))
+  }
+}
+
+# Generate ordinal outcomes
+y <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ]))
+cat("Outcome distribution:", table(y), "\n")
+```
+
+## Model Fitting
+
+Now let's fit the CLogLog Ordinal BART model:
+
+```{r demo1-model-fitting}
+# Split data into train and test sets
+train_idx <- sample(1:n, size = floor(0.8 * n))
+test_idx <- setdiff(1:n, train_idx)
+
+X_train <- X[train_idx, ]
+y_train <- y[train_idx]
+X_test <- X[test_idx, ]
+y_test <- y[test_idx]
+
+# Fit CLogLog Ordinal BART model
+out <- stochtree::cloglog_ordinal_bart(
+  X = X_train,
+  y = y_train,
+  X_test = X_test,
+  n_samples_mcmc = 1000,
+  n_burnin = 500,
+  n_thin = 1
+)
+```
+
+## Model Results and Interpretation
+
+Let's examine the posterior samples and model performance:
+
+```{r demo1-results}
+# Compare forest predictions with the true lambda function (for training and test sets)
+lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train)
+plot(lambda_pred_train, true_lambda_function[train_idx])
+abline(a=0,b=1,col='blue', lwd=2)
+cor_train <- cor(true_lambda_function[train_idx], lambda_pred_train)
+text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red')
+
+lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test)
+plot(lambda_pred_test, true_lambda_function[test_idx])
+abline(a=0,b=1,col='blue', lwd=2)
+cor_test <- cor(true_lambda_function[test_idx], lambda_pred_test)
+text(min(true_lambda_function[test_idx]), max(true_lambda_function[test_idx]), paste('Correlation:', round(cor_test, 3)), adj = 0, col = 'red')
+
+# Estimated ordinal class probabilities for the training set
+# (sweep() adds each cutpoint draw to the matching MCMC column of the forest
+# predictions; a plain `+` would recycle the draws down the rows instead)
+est_probs_train <- matrix(0, nrow=length(train_idx), ncol=n_categories)
+for (j in 1:n_categories) {
+  if (j == 1) {
+    est_probs_train[, j] <- rowMeans(1 - exp(-exp(sweep(out$forest_predictions_train, 2, out$gamma_samples[j, ], "+"))))
+  } else if (j == n_categories) {
+    est_probs_train[, j] <- 1 - rowSums(est_probs_train[, 1:(j - 1), drop = FALSE])
+  } else {
+    est_probs_train[, j] <- rowMeans(exp(-exp(sweep(out$forest_predictions_train, 2, out$gamma_samples[j - 1, ], "+"))) *
+      (1 - exp(-exp(sweep(out$forest_predictions_train, 2, out$gamma_samples[j, ], "+")))))
+  }
+}
+
+# Compare estimated vs true class probabilities for training set
+for (j in 1:n_categories) {
+  plot(true_probs[train_idx, j], est_probs_train[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j))
+  abline(a = 0, b = 1, col = 'blue', lwd = 2)
+  cor_train_prob <- cor(true_probs[train_idx, j], est_probs_train[, j])
+  text(min(true_probs[train_idx, j]), max(est_probs_train[, j]), paste('Correlation:', round(cor_train_prob, 3)), adj = 0, col = 'red')
+}
+
+# Estimated ordinal class probabilities for the test set
+# (sweep() aligns draws across columns, as above)
+est_probs_test <- matrix(0, nrow=length(test_idx), ncol=n_categories)
+for (j in 1:n_categories) {
+  if (j == 1) {
+    est_probs_test[, j] <- rowMeans(1 - exp(-exp(sweep(out$forest_predictions_test, 2, out$gamma_samples[j, ], "+"))))
+  } else if (j == n_categories) {
+    est_probs_test[, j] <- 1 - rowSums(est_probs_test[, 1:(j - 1), drop = FALSE])
+  } else {
+    est_probs_test[, j] <- rowMeans(exp(-exp(sweep(out$forest_predictions_test, 2, out$gamma_samples[j - 1, ], "+"))) *
+      (1 - exp(-exp(sweep(out$forest_predictions_test, 2, out$gamma_samples[j, ], "+")))))
+  }
+}
+
+# Compare estimated vs true class probabilities for test set
+for (j in 1:n_categories) {
+  plot(true_probs[test_idx, j], est_probs_test[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j))
+  abline(a = 0, b = 1, col = 'blue', lwd = 2)
+  cor_test_prob <- cor(true_probs[test_idx, j], est_probs_test[, j])
+  text(min(true_probs[test_idx, j]), max(est_probs_test[, j]), paste('Correlation:', round(cor_test_prob, 3)), adj = 0, col = 'red')
+}
+```
+
+# Conclusion
+
+The CLogLog Ordinal BART model in `stochtree` provides a flexible and powerful approach for modeling ordinal outcomes. Its asymmetric link makes it especially well suited to settings such as:
+
+- Rare events (e.g., credit default, fraud detection, system failures, adverse drug reactions)
+- Threshold-driven escalation (e.g., credit risk escalation, dose-response toxicity, engagement drop-offs)
+- Discrete survival outcomes (e.g., time-to-default, customer churn, progression-free survival)
+
+The implementation in `stochtree` builds on the framework of @alam2025unified.
+
+# References
diff --git a/vignettes/vignettes.bib b/vignettes/vignettes.bib
index a1b0a768..65a6f152 100644
--- a/vignettes/vignettes.bib
+++ b/vignettes/vignettes.bib
@@ -117,4 +117,11 @@ @book{scholkopf2002learning
 author={Sch{\"o}lkopf, Bernhard and Smola, Alexander J},
 year={2002},
 publisher={MIT press}
-}
\ No newline at end of file
+}
+
+@article{alam2025unified,
+  title={A Unified Bayesian Nonparametric Framework for Ordinal, Survival, and Density Regression Using the Complementary Log-Log Link},
+  author={Alam, Entejar and Linero, Antonio R},
+  journal={arXiv preprint arXiv:2502.00606},
+  year={2025}
+}

From 132071ebefe99cb657c3a9a95cc10d5609c0fd19 Mon Sep 17 00:00:00 2001
From: Entejar Alam
Date: Tue, 30 Sep 2025 15:11:46 -0500
Subject: [PATCH 9/9] Update leaf_model.h

---
 include/stochtree/leaf_model.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h
index 6adf9c23..02bf4d16 100644
--- a/include/stochtree/leaf_model.h
+++ b/include/stochtree/leaf_model.h
@@ -403,7 +403,7 @@ class GaussianConstantSuffStat {
   sum_w = 0.0;
   sum_yw = 0.0;
   }
-  /*!
+  /*!
   * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat`
   *
   * \param suff_stat Sufficient statistic to be added to the current sufficient statistics