Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ S3method(predict,bcfmodel)
export(bart)
export(bcf)
export(calibrateInverseGammaErrorVariance)
export(cloglog_ordinal_bart)
export(computeForestLeafIndices)
export(computeForestLeafVariances)
export(computeForestMaxLeafIndex)
Expand Down
181 changes: 181 additions & 0 deletions R/cloglog_ordinal_bart.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
#' Run the BART algorithm for ordinal outcomes using a complementary log-log link
#'
#' @param X A numeric matrix of predictors (training data).
#' @param y A numeric vector of ordinal outcomes (positive integers starting from 1).
#' @param X_test An optional numeric matrix of predictors (test data).
#' @param n_trees Number of trees in the BART ensemble. Default: `50`.
#' @param n_samples_mcmc Total number of MCMC samples to draw. Default: `500`.
#' @param n_burnin Number of burn-in samples to discard. Default: `250`.
#' @param n_thin Thinning interval for MCMC samples. Default: `1`.
#' @param alpha_gamma Shape parameter for the log-gamma prior on cutpoints. Default: `2.0`.
#' @param beta_gamma Rate parameter for the log-gamma prior on cutpoints. Default: `2.0`.
#' @param variable_weights Optional vector of variable weights for splitting (default: equal weights).
#' @param feature_types Optional vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous).
#' @export


cloglog_ordinal_bart <- function(X, y, X_test = NULL,
n_trees = 50,
n_samples_mcmc = 500,
n_burnin = 250,
n_thin = 1,
alpha_gamma = 2.0,
beta_gamma = 2.0,
variable_weights = NULL,
feature_types = NULL,
seed = NULL) {

# BART parameters
alpha_bart <- 0.95
beta_bart <- 2
min_samples_in_leaf <- 5
max_depth <- 10
scale_leaf <- 2 / sqrt(n_trees)
cutpoint_grid_size <- 100 # Needed for stochtree:::sample_mcmc_one_iteration_cpp (for GFR), not used in ordinal BART

# Fixed for identifiability (can be pass as argument later if desired)
gamma_0 = 0.0 # First gamma cutpoint fixed at gamma_0 = 0

# Determine whether a test dataset is provided
has_test <- !is.null(X_test)
# Data checks
if (!is.matrix(X)) X <- as.matrix(X)
if (!is.numeric(y)) y <- as.numeric(y)
if (has_test && !is.matrix(X_test)) X_test <- as.matrix(X_test)

n_samples <- nrow(X)
n_features <- ncol(X)

if (any(y < 1) || any(y != round(y))) {
stop("Ordinal outcome y must contain positive integers starting from 1")
}

# Convert from 1-based (R) to 0-based (C++) indexing
ordinal_outcome <- as.integer(y - 1)
n_levels <- max(y) # Number of ordinal categories

if (n_levels < 2) {
stop("Ordinal outcome must have at least 2 categories")
}

if (is.null(variable_weights)) {
variable_weights <- rep(1.0, n_features)
}

if (is.null(feature_types)) {
feature_types <- rep(0L, n_features)
}

if (!is.null(seed)) {
set.seed(seed)
}

keep_idx <- seq((n_burnin + 1), n_samples_mcmc, by = n_thin)
n_keep <- length(keep_idx)

forest_pred_train <- matrix(0, n_samples, n_keep)
if (has_test) {
n_samples_test <- nrow(X_test)
forest_pred_test <- matrix(0, n_samples_test, n_keep)
}
gamma_samples <- matrix(0, n_levels - 1, n_keep)
latent_samples <- matrix(0, n_samples, n_keep)

# Initialize other model structures as before
dataX <- stochtree::createForestDataset(X)
if (has_test) {
dataXtest <- stochtree::createForestDataset(X_test)
}
outcome_data <- stochtree::createOutcome(as.numeric(ordinal_outcome))
active_forest <- stochtree::createForest(as.integer(n_trees), 1L, TRUE, FALSE) # Use constant leaves
active_forest$set_root_leaves(0.0)
split_prior <- stochtree:::tree_prior_cpp(alpha_bart, beta_bart, min_samples_in_leaf, max_depth)
forest_samples <- stochtree::createForestSamples(as.integer(n_trees), 1L, TRUE, FALSE) # Use constant leaves
forest_tracker <- stochtree:::forest_tracker_cpp(
dataX$data_ptr,
as.integer(feature_types),
as.integer(n_trees),
as.integer(n_samples)
)
stochtree:::ordinal_aux_data_initialize_cpp(forest_tracker, as.integer(n_samples), as.integer(n_levels))

# Initialize gamma parameters to zero (slot 2)
initial_gamma <- rep(0.0, n_levels - 1)
for (i in seq_along(initial_gamma)) {
stochtree:::ordinal_aux_data_set_cpp(forest_tracker, 2, i - 1, initial_gamma[i])
}
stochtree:::ordinal_aux_data_update_cumsum_exp_cpp(forest_tracker)

# Initialize forest predictions slot to zero (slot 1)
for (i in 1:n_samples) {
stochtree:::ordinal_aux_data_set_cpp(forest_tracker, 1, i - 1, 0.0)
}

ordinal_sampler <- stochtree:::ordinal_sampler_cpp()
rng <- stochtree::createCppRNG(if (is.null(seed)) sample.int(.Machine$integer.max, 1) else seed)

# Set up sweep indices for tree updates (sample all trees each iteration)
sweep_indices <- as.integer(seq(0, n_trees - 1))

sample_counter <- 0
for (i in 1:n_samples_mcmc) {
keep_sample <- i %in% keep_idx
if (keep_sample) {
sample_counter <- sample_counter + 1
}

# 1. Sample forest using MCMC
stochtree:::sample_mcmc_one_iteration_cpp(
dataX$data_ptr, outcome_data$data_ptr, forest_samples$forest_container_ptr,
active_forest$forest_ptr, forest_tracker, split_prior, rng$rng_ptr,
sweep_indices, as.integer(feature_types), as.integer(cutpoint_grid_size),
scale_leaf, variable_weights, alpha_gamma, beta_gamma, 1.0, 4L, keep_sample
)

# Set auxiliary data slot 1 to current forest predictions = lambda_hat = sum of all the tree predictions
# This is needed for updating gamma parameters, latent z_i's
forest_pred_current <- active_forest$predict(dataX)
for (j in 1:n_samples) {
stochtree:::ordinal_aux_data_set_cpp(forest_tracker, 1, j - 1, forest_pred_current[j])
}

# 2. Sample latent z_i's using truncated exponential
stochtree:::ordinal_sampler_update_latent_variables_cpp(
ordinal_sampler, dataX$data_ptr, outcome_data$data_ptr, forest_tracker, rng$rng_ptr
)

# 3. Sample gamma parameters
stochtree:::ordinal_sampler_update_gamma_params_cpp(
ordinal_sampler, dataX$data_ptr, outcome_data$data_ptr, forest_tracker,
alpha_gamma, beta_gamma, gamma_0, rng$rng_ptr
)

# 4. Update cumulative sum of exp(gamma) values
stochtree:::ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, forest_tracker)

if (keep_sample) {
forest_pred_train[, sample_counter] <- active_forest$predict(dataX)
if (has_test) {
forest_pred_test[, sample_counter] <- active_forest$predict(dataXtest)
}
gamma_current <- stochtree:::ordinal_aux_data_get_vector_cpp(forest_tracker, 2)
gamma_samples[, sample_counter] <- gamma_current
latent_current <- stochtree:::ordinal_aux_data_get_vector_cpp(forest_tracker, 0)
latent_samples[, sample_counter] <- latent_current
}
}

result <- list(
forest_predictions_train = forest_pred_train,
forest_predictions_test = if (has_test) forest_pred_test else NULL,
gamma_samples = gamma_samples,
latent_samples = latent_samples,
scale_leaf = scale_leaf,
ordinal_outcome = ordinal_outcome,
n_trees = n_trees,
n_keep = n_keep
)

class(result) <- "cloglog_ordinal_bart"
return(result)
}
48 changes: 44 additions & 4 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -624,12 +624,12 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums)
.Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums)
}

sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample, num_threads) {
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample, num_threads))
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample) {
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample))
}

sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_threads) {
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_threads))
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
}

sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) {
Expand Down Expand Up @@ -692,6 +692,46 @@ sample_without_replacement_integer_cpp <- function(population_vector, sampling_p
.Call(`_stochtree_sample_without_replacement_integer_cpp`, population_vector, sampling_probs, sample_size)
}

ordinal_aux_data_initialize_cpp <- function(tracker_ptr, num_observations, n_levels) {
invisible(.Call(`_stochtree_ordinal_aux_data_initialize_cpp`, tracker_ptr, num_observations, n_levels))
}

ordinal_aux_data_get_cpp <- function(tracker_ptr, type_idx, obs_idx) {
.Call(`_stochtree_ordinal_aux_data_get_cpp`, tracker_ptr, type_idx, obs_idx)
}

ordinal_aux_data_set_cpp <- function(tracker_ptr, type_idx, obs_idx, value) {
invisible(.Call(`_stochtree_ordinal_aux_data_set_cpp`, tracker_ptr, type_idx, obs_idx, value))
}

ordinal_aux_data_get_vector_cpp <- function(tracker_ptr, type_idx) {
.Call(`_stochtree_ordinal_aux_data_get_vector_cpp`, tracker_ptr, type_idx)
}

ordinal_aux_data_set_vector_cpp <- function(tracker_ptr, type_idx, values) {
invisible(.Call(`_stochtree_ordinal_aux_data_set_vector_cpp`, tracker_ptr, type_idx, values))
}

ordinal_aux_data_update_cumsum_exp_cpp <- function(tracker_ptr) {
invisible(.Call(`_stochtree_ordinal_aux_data_update_cumsum_exp_cpp`, tracker_ptr))
}

ordinal_sampler_cpp <- function() {
.Call(`_stochtree_ordinal_sampler_cpp`)
}

ordinal_sampler_update_latent_variables_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, rng_ptr) {
invisible(.Call(`_stochtree_ordinal_sampler_update_latent_variables_cpp`, sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, rng_ptr))
}

ordinal_sampler_update_gamma_params_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr) {
invisible(.Call(`_stochtree_ordinal_sampler_update_gamma_params_cpp`, sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr))
}

ordinal_sampler_update_cumsum_exp_cpp <- function(sampler_ptr, tracker_ptr) {
invisible(.Call(`_stochtree_ordinal_sampler_update_cumsum_exp_cpp`, sampler_ptr, tracker_ptr))
}

init_json_cpp <- function() {
.Call(`_stochtree_init_json_cpp`)
}
Expand Down
4 changes: 4 additions & 0 deletions include/stochtree/category_tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@
#include <stochtree/log.h>
#include <stochtree/meta.h>

#include <cmath>
#include <map>
#include <numeric>
#include <random>
#include <set>
#include <string>
#include <vector>

namespace StochTree {
Expand Down
4 changes: 4 additions & 0 deletions include/stochtree/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@
#include <stochtree/log.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iomanip>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

Expand Down
5 changes: 5 additions & 0 deletions include/stochtree/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
#include <nlohmann/json.hpp>
#include <stochtree/tree.h>

#include <algorithm>
#include <deque>
#include <fstream>
#include <optional>
#include <random>
#include <unordered_map>

namespace StochTree {

Expand Down
2 changes: 2 additions & 0 deletions include/stochtree/cutpoint_candidates.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
#include <stochtree/meta.h>
#include <stochtree/partition_tracker.h>

#include <tuple>

namespace StochTree {

/*! \brief Computing and tracking cutpoints available for a given feature at a given node
Expand Down
1 change: 1 addition & 0 deletions include/stochtree/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <stochtree/io.h>
#include <stochtree/log.h>
#include <stochtree/meta.h>
#include <memory>

namespace StochTree {

Expand Down
6 changes: 6 additions & 0 deletions include/stochtree/ensemble.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
#include <stochtree/tree.h>
#include <nlohmann/json.hpp>

#include <algorithm>
#include <deque>
#include <optional>
#include <random>
#include <unordered_map>

using json = nlohmann::json;

namespace StochTree {
Expand Down
2 changes: 2 additions & 0 deletions include/stochtree/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

Expand Down
Loading
Loading