diff --git a/DESCRIPTION b/DESCRIPTION index f93b69e1e..1253c0cf8 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: tune Title: Tidy Tuning Tools -Version: 1.1.1.9000 +Version: 1.1.1.9001 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-2402-136X")), @@ -27,7 +27,7 @@ Imports: GPfit, hardhat (>= 1.2.0), lifecycle (>= 1.0.0), - parsnip (>= 1.0.2), + parsnip (>= 1.1.0.9001), purrr (>= 1.0.0), recipes (>= 1.0.4), rlang (>= 1.0.2), @@ -37,7 +37,7 @@ Imports: tidyselect (>= 1.1.2), vctrs (>= 0.6.1), withr, - workflows (>= 1.0.0), + workflows (>= 1.1.3.9001), yardstick (>= 1.0.0) Suggests: C50, @@ -50,6 +50,9 @@ Suggests: testthat (>= 3.0.0), xgboost, xml2 +Remotes: + tidymodels/parsnip#955, + tidymodels/workflows#199 Config/Needs/website: pkgdown, tidymodels, kknn, doParallel, doFuture, tidyverse/tidytemplate Config/testthat/edition: 3 diff --git a/NAMESPACE b/NAMESPACE index 9540a53f7..f5196af79 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -127,6 +127,7 @@ S3method(vec_ptype2,tune_results.tune_results) S3method(vec_restore,iteration_results) S3method(vec_restore,resample_results) S3method(vec_restore,tune_results) +S3method(weight_propensity,tune_results) export(.catch_and_log) export(.catch_and_log_fit) export(.config_key_from_metrics) @@ -295,6 +296,7 @@ importFrom(hardhat,extract_workflow) importFrom(hardhat,tune) importFrom(parsnip,get_from_env) importFrom(parsnip,required_pkgs) +importFrom(parsnip,weight_propensity) importFrom(purrr,map_int) importFrom(recipes,all_outcomes) importFrom(recipes,all_predictors) diff --git a/R/weight_propensity.R b/R/weight_propensity.R new file mode 100644 index 000000000..ddc53ec7c --- /dev/null +++ b/R/weight_propensity.R @@ -0,0 +1,143 @@ +#' Helper for bridging two-stage causal fits +#' +#' @description +#' `weight_propensity()` is a helper function to more easily link the +#' propensity and outcome models in causal workflows. 
In the case of a +#' single model fit, as with `model_fit`s or `workflow`s, the function +#' is roughly analogous to an `augment()` method that additionally takes in +#' a propensity weighting function. For `tune_results`, the method carries +#' out this same augment-adjacent procedure on the training data underlying +#' the resampling object for each element of the analysis set. +#' +#' @inheritParams parsnip::weight_propensity.model_fit +#' +#' @inherit parsnip::weight_propensity.model_fit return +#' +#' @inherit parsnip::weight_propensity.model_fit references +#' +#' @examplesIf tune:::should_run_examples(suggests = "modeldata") +#' # load needed packages +#' library(modeldata) +#' library(parsnip) +#' library(workflows) +#' library(rsample) +#' +#' library(ggplot2) +#' library(dplyr) +#' library(purrr) +#' +#' # example data: model causal estimate for `Class` +#' two_class_dat <- two_class_dat[1:250,] +#' two_class_dat +#' +#' # see `propensity::wt_ate()` for a more realistic example +#' # of a propensity weighting function +#' silly_wt_fn <- function(.propensity, .exposure = NULL, ...) 
{
+#'   .propensity
+#' }
+#'
+#' propensity_wf <- workflow(Class ~ B, logistic_reg())
+#' outcome_wf <- workflow(A ~ Class, linear_reg()) %>% add_case_weights(.wts)
+#'
+#' # single model --------------------------------------------------------------
+#' propensity_fit <- fit(propensity_wf, two_class_dat)
+#'
+#' two_class_weighted <-
+#'   weight_propensity(propensity_fit, silly_wt_fn, data = two_class_dat)
+#'
+#' two_class_weighted
+#'
+#' outcome_fit <- fit(outcome_wf, two_class_weighted)
+#'
+#' outcome_fit %>% extract_fit_engine() %>% coef()
+#'
+#' # resampled model -----------------------------------------------------------
+#' set.seed(1)
+#' boots <- bootstraps(two_class_dat[1:250,], times = 100)
+#'
+#' res_tm <-
+#'   # fit the propensity model to resamples
+#'   fit_resamples(
+#'     propensity_wf,
+#'     resamples = boots,
+#'     # note `extract = identity` rather than `extract`
+#'     control = control_resamples(extract = identity)
+#'   ) %>%
+#'   # determine weights for outcome model based on
+#'   # propensity model's predictions
+#'   weight_propensity(silly_wt_fn) %>%
+#'   # fit outcome workflow using generated `.wts`
+#'   fit_resamples(
+#'     outcome_wf,
+#'     resamples = .,
+#'     # would usually `extract = tidy` here
+#'     control = control_resamples(extract = identity)
+#'   )
+#'
+#' # extracts contain the properly resampled fitted workflows:
+#' collect_extracts(res_tm)
+#'
+#' # plot the properly resampled distribution of estimates:
+#' collect_extracts(res_tm) %>%
+#'   pull(.extracts) %>%
+#'   map(extract_fit_engine) %>%
+#'   map(coef) %>%
+#'   bind_rows() %>%
+#'   ggplot() +
+#'   aes(x = ClassClass2) +
+#'   geom_histogram()
+#'
+#' @name weight_propensity
+#' @aliases weight_propensity.tune_results
+#' @importFrom parsnip weight_propensity
+#' @method weight_propensity tune_results
+#' @export
+weight_propensity.tune_results <- function(object, wt_fn, ...) {
+  if (rlang::is_missing(wt_fn) || !is.function(wt_fn)) {
+    cli::cli_abort("{.arg wt_fn} must be a function.")
+  }
+
+  # Each resample's first extract must be a fitted workflow, which is what
+  # `extract = identity` stores. Only the first resample is inspected here.
+  wf_1 <- purrr::pluck(object, ".extracts", 1, ".extracts", 1)
+  if (!inherits(wf_1, "workflow")) {
+    cli::cli_abort(
+      "{.arg object} must have been generated with the \\
+      {.help [control option](tune::control_grid)} {.code extract = identity}."
+    )
+  }
+
+  # `data` is taken from each resample's analysis set (see `augment_split()`),
+  # so reject a user-supplied `data` before `...` is forwarded.
+  dots <- rlang::list2(...)
+  if ("data" %in% names(dots)) {
+    cli::cli_abort(
+      "The {.cls tune_results} method for {.fn weight_propensity} does not take \\
+      a {.arg data} argument, but one was supplied."
+    )
+  }
+
+  # Replace every split with a copy whose data carries a `.wts` column
+  # computed from that resample's own fitted workflow.
+  for (resample in seq_along(object$splits)) {
+    object$splits[[resample]] <-
+      augment_split(
+        object$splits[[resample]],
+        object$.extracts[[resample]]$.extracts[[1]],
+        wt_fn = wt_fn,
+        ...
+      )
+  }
+
+  # Rebuild the rset from `splits` plus every id column (`id`, `id2`, ...).
+  # Keeping only `id` would drop the additional identifiers of resampling
+  # schemes that have more than one, e.g. repeated cross-validation.
+  id_cols <- grep("^id[0-9]*$", names(object), value = TRUE)
+
+  tibble::new_tibble(
+    object[, c("splits", id_cols)],
+    !!!attr(object, "rset_info")$att,
+    class = c(attr(object, "rset_info")$att$class, "rset")
+  )
+}
+
+# Add a `.wts` column to the data underlying a single resampling split.
+#
+# `split` is the split to modify and `workflow` the fitted workflow extracted
+# for that resample; `wt_fn` and `...` are forwarded to `weight_propensity()`.
+# Weights are computed once per distinct analysis-set row and written back to
+# the full data by row position. Rows outside the analysis set are never
+# assigned a weight (presumably left missing -- confirm if that matters).
+# Returns the modified split.
+augment_split <- function(split, workflow, wt_fn, ...) {
+  # Temporary row index used to match weights back to the full data. Use the
+  # row count rather than the length of the first column, which differs for
+  # matrix or data-frame columns and is undefined for zero-column data.
+  split[["data"]]$..id <- seq_len(nrow(split[["data"]]))
+  d <- rsample::analysis(split)
+  # The analysis set may contain the same row more than once (e.g. with
+  # bootstraps); keep one copy of each so every row is weighted once.
+  d <- vctrs::vec_slice(d, !duplicated(d$..id))
+  d <- weight_propensity(workflow, wt_fn, ..., data = d)
+
+  split[["data"]][d$..id, ".wts"] <- d$.wts
+  # Drop the temporary index so the data keeps its original columns + `.wts`.
+  split[["data"]]$..id <- NULL
+
+  split
+}
diff --git a/_pkgdown.yml b/_pkgdown.yml
index a0407dab9..e86978ce7 100644
--- a/_pkgdown.yml
+++ b/_pkgdown.yml
@@ -55,6 +55,7 @@ reference:
     - coord_obs_pred
     - conf_mat_resampled
     - example_ames_knn
+    - starts_with("weight_propensity")
   - title: Developer functions
     contents:
     - merge.recipe
diff --git a/man/weight_propensity.Rd b/man/weight_propensity.Rd
new file mode 100644
index 000000000..399c67a20
--- /dev/null
+++ b/man/weight_propensity.Rd
@@ -0,0 +1,119 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/weight_propensity.R
+\name{weight_propensity}
+\alias{weight_propensity}
+\alias{weight_propensity.tune_results}
+\title{Helper for bridging two-stage causal fits}
+\usage{
+\method{weight_propensity}{tune_results}(object, wt_fn, ...)
+}
+\arguments{
+\item{object}{The object containing the model fit(s) that will generate
+predictions used to calculate propensity weights. Currently, either a
+\link[parsnip:fit.model_spec]{parsnip model fit}, fitted
+\link[workflows:workflow]{workflow}, or
+tuning results (\code{?tune::fit_resamples}) object. If a tuning result, the
+object must have been generated with the control argument
+(\code{?tune::control_resamples}) \code{extract = identity}.}
+
+\item{wt_fn}{A function used to calculate the propensity weights. The first
+argument gives the predicted probability of exposure, the true value for
+which is provided in the second argument. 
See \code{?propensity::wt_ate()} for +an example.} + +\item{...}{Additional arguments passed to \code{wt_fn}.} +} +\value{ +For \code{model_fit} and fitted \code{workflow} input, a modified version of the data +set supplied in \code{data} that contains a \code{.wts} column with class +\code{importance_weights}. For \code{tune_results} input, a modified version of the +resampling object underlying the tuning results containing a new \code{.wts} column +with propensity values corresponding to each element of the analysis set. +} +\description{ +\code{weight_propensity()} is a helper function to more easily link the +propensity and outcome models in causal workflows. In the case of a +single model fit, as with \code{model_fit}s or \code{workflow}s, the function +is roughly analogous to an \code{augment()} method that additionally takes in +a propensity weighting function. For \code{tune_results}, the method carries +out this same augment-adjacent procedure on the training data underlying +the resampling object for each element of the analysis set. +} +\examples{ +\dontshow{if (tune:::should_run_examples(suggests = "modeldata")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +# load needed packages +library(modeldata) +library(parsnip) +library(workflows) +library(rsample) + +library(ggplot2) +library(dplyr) +library(purrr) + +# example data: model causal estimate for `Class` +two_class_dat <- two_class_dat[1:250,] +two_class_dat + +# see `propensity::wt_ate()` for a more realistic example +# of a propensity weighting function +silly_wt_fn <- function(.propensity, .exposure = NULL, ...) 
{ + .propensity +} + +propensity_wf <- workflow(Class ~ B, logistic_reg()) +outcome_wf <- workflow(A ~ Class, linear_reg()) \%>\% add_case_weights(.wts) + +# single model -------------------------------------------------------------- +propensity_fit <- fit(propensity_wf, two_class_dat) + +two_class_weighted <- + weight_propensity(propensity_fit, silly_wt_fn, data = two_class_dat) + +two_class_weighted + +outcome_fit <- fit(outcome_wf, two_class_weighted) + +outcome_fit \%>\% extract_fit_engine() \%>\% coef() + +# resampled model ----------------------------------------------------------- +set.seed(1) +boots <- bootstraps(two_class_dat[1:250,], times = 100) + +res_tm <- + # fit the propensity model to resamples + fit_resamples( + propensity_wf, + resamples = boots, + # note `extract = identity` rather than `extract` + control = control_resamples(extract = identity) + ) \%>\% + # determine weights for outcome model based on + # propensity model's predictions + weight_propensity(silly_wt_fn) \%>\% + # fit outcome workflow using generated `.wts` + fit_resamples( + outcome_wf, + resamples = ., + # would usually `extract = tidy` here + control = control_resamples(extract = identity) + ) + +# extracts contain the properly resampled fitted workflows: +collect_extracts(res_tm) + +# plot the properly resampled distribution of estimates: +collect_extracts(res_tm) \%>\% + pull(.extracts) \%>\% + map(extract_fit_engine) \%>\% + map(coef) \%>\% + bind_rows() \%>\% + ggplot() + + aes(x = ClassClass2) + + geom_histogram() +\dontshow{\}) # examplesIf} +} +\references{ +Barrett M & D'Agostino McGowan L (forthcoming). +\emph{Causal Inference in R}. 
\url{https://www.r-causal.org/}
+}
diff --git a/tests/testthat/_snaps/weight_propensity.md b/tests/testthat/_snaps/weight_propensity.md
new file mode 100644
index 000000000..e246aa85f
--- /dev/null
+++ b/tests/testthat/_snaps/weight_propensity.md
@@ -0,0 +1,42 @@
+# errors informatively with bad input
+
+    Code
+      weight_propensity(res_fit_resamples_bad, silly_wt_fn)
+    Condition
+      Error in `weight_propensity()`:
+      ! `object` must have been generated with the control option (`?tune::control_grid()`) `extract = identity`.
+
+---
+
+    Code
+      weight_propensity(res_fit_resamples)
+    Condition
+      Error in `weight_propensity()`:
+      ! `wt_fn` must be a function.
+
+---
+
+    Code
+      weight_propensity(res_fit_resamples, "boop")
+    Condition
+      Error in `weight_propensity()`:
+      ! `wt_fn` must be a function.
+
+---
+
+    Code
+      weight_propensity(res_fit_resamples, function(...) {
+        -1L
+      })
+    Condition
+      Error in `hardhat::importance_weights()`:
+      ! `x` can't contain negative weights.
+
+---
+
+    Code
+      weight_propensity(res_fit_resamples, silly_wt_fn, data = two_class_dat)
+    Condition
+      Error in `weight_propensity()`:
+      ! The <tune_results> method for `weight_propensity()` does not take a `data` argument, but one was supplied.
+
diff --git a/tests/testthat/test-weight_propensity.R b/tests/testthat/test-weight_propensity.R
new file mode 100644
index 000000000..a14358716
--- /dev/null
+++ b/tests/testthat/test-weight_propensity.R
@@ -0,0 +1,186 @@
+test_that("basic functionality", {
+  skip_on_cran()
+  skip_if_not_installed("modeldata")
+  library(modeldata)
+  library(parsnip)
+  library(workflows)
+  library(rsample)
+
+  silly_wt_fn <- function(.propensity, .exposure = NULL, ...) {
+    .propensity
+  }
+
+  propensity_wf <- workflow(Class ~ B, logistic_reg())
+  outcome_wf <- workflow(A ~ Class, linear_reg()) %>% add_case_weights(.wts)
+
+  set.seed(1)
+  boots <- bootstraps(two_class_dat)
+
+  res_fit_resamples <-
+    fit_resamples(
+      propensity_wf,
+      resamples = boots,
+      # can't be `save_workflow = TRUE`, as we need the _fitted_ workflow
+      control = control_resamples(extract = identity)
+    )
+
+  res_weight_propensity <-
+    res_fit_resamples %>%
+    weight_propensity(silly_wt_fn)
+
+  # `weight_propensity()` preserves rset properties:
+  preserved <- c("class", "times", "apparent", "breaks", "pool")
+  expect_equal(attributes(boots)[preserved], attributes(res_weight_propensity)[preserved])
+  expect_equal(boots$id, res_weight_propensity$id)
+  expect_equal(
+    purrr::map(boots$splits, purrr::pluck, "in_id"),
+    purrr::map(res_weight_propensity$splits, purrr::pluck, "in_id")
+  )
+  expect_equal(
+    purrr::map(boots$splits, purrr::pluck, "data"),
+    purrr::map(res_weight_propensity$splits, purrr::pluck, "data") %>%
+      purrr::map(dplyr::select, -.wts)
+  )
+
+  # confirming that `.wts` change with every resample. this expectation
+  # specifically only works because `wt_fn ~= identity`
+  wf_fits <- purrr::map(res_fit_resamples$.extracts, purrr::pluck, ".extracts", 1)
+  expect_equal(
+    purrr::map(boots$splits, analysis) %>%
+      purrr::map2(wf_fits, ., predict, type = "prob") %>%
+      purrr::map(dplyr::pull, 2),
+    purrr::map(res_weight_propensity$splits, analysis) %>%
+      purrr::map(dplyr::pull, .wts) %>%
+      purrr::map(as.numeric)
+  )
+
+  # output is valid input to another call to `fit_resamples()`
+  res_final <-
+    res_weight_propensity %>%
+    fit_resamples(outcome_wf, resamples = .)
+
+  expect_s3_class(res_final, "tune_results")
+})
+
+test_that("errors informatively with bad input", {
+  skip_if_not_installed("modeldata")
+  library(modeldata)
+  library(parsnip)
+  library(workflows)
+  library(rsample)
+
+  silly_wt_fn <- function(.propensity, .exposure = NULL, ...) {
+    .propensity
+  }
+
+
+  propensity_wf <- workflow(Class ~ B, logistic_reg())
+  outcome_wf <- workflow(A ~ Class, linear_reg()) %>% add_case_weights(.wts)
+
+  set.seed(1)
+  boots <- bootstraps(two_class_dat)
+
+  res_fit_resamples <-
+    fit_resamples(
+      propensity_wf,
+      resamples = boots,
+      # can't be `save_workflow = TRUE`, as we need the _fitted_ workflow
+      control = control_resamples(extract = identity)
+    )
+
+  res_fit_resamples_bad <-
+    fit_resamples(propensity_wf, resamples = boots)
+
+  res_weight_propensity <-
+    res_fit_resamples %>%
+    weight_propensity(silly_wt_fn)
+
+  # did not set `control`
+  expect_snapshot(
+    error = TRUE,
+    weight_propensity(res_fit_resamples_bad, silly_wt_fn)
+  )
+
+  # bad `wt_fn`
+  expect_snapshot(
+    error = TRUE,
+    weight_propensity(res_fit_resamples)
+  )
+
+  expect_snapshot(
+    error = TRUE,
+    weight_propensity(res_fit_resamples, "boop")
+  )
+
+  expect_snapshot(
+    error = TRUE,
+    weight_propensity(res_fit_resamples, function(...) {-1L})
+  )
+
+  # mistakenly supplied `data`
+  expect_snapshot(
+    error = TRUE,
+    weight_propensity(res_fit_resamples, silly_wt_fn, data = two_class_dat)
+  )
+})
+
+test_that("results match manual calculation", {
+  skip_on_cran()
+  skip_if_not_installed("modeldata")
+  library(modeldata)
+  library(parsnip)
+  library(workflows)
+  library(rsample)
+  library(hardhat)
+
+  silly_wt_fn <- function(.propensity, .exposure = NULL, ...) {
+    .propensity
+  }
+
+  propensity_wf <- workflow(Class ~ B, logistic_reg())
+  outcome_wf <- workflow(A ~ Class, linear_reg()) %>% add_case_weights(.wts)
+
+  set.seed(1)
+  boots <- bootstraps(two_class_dat)
+
+  # our way:
+  res_tm <-
+    fit_resamples(
+      propensity_wf,
+      resamples = boots,
+      control = control_resamples(extract = identity)
+    ) %>%
+    weight_propensity(silly_wt_fn) %>%
+    fit_resamples(
+      outcome_wf,
+      resamples = .,
+      # would usually `extract = tidy` here, but don't want to
+      # register broom:::tidy.lm
+      control = control_resamples(extract = identity)
+    )
+
+  # a la r-causal:
+  fit_ipw <- function(split, ...) {
+    .df <- analysis(split)
+
+    propensity_model <- fit(propensity_wf, .df)
+
+    preds <- predict(propensity_model, new_data = .df, type = "prob")
+    .df$.wts <- importance_weights(silly_wt_fn(preds[[2]]))
+
+    fit(outcome_wf, .df)
+  }
+
+  res_rc <- purrr::map(boots$splits, fit_ipw)
+
+  # comparison:
+  expect_equal(
+    res_tm %>%
+      collect_extracts() %>%
+      dplyr::pull(.extracts) %>%
+      purrr::map(extract_fit_engine) %>%
+      purrr::map(coef),
+    purrr::map(res_rc, extract_fit_engine) %>%
+      purrr::map(coef)
+  )
+})