
Finding Subgroups in Treatment Effects with Conformal Trees

Usage

r2p_hte(
  data,
  target,
  treatment,
  learner,
  cv_folds = 1,
  alpha = 0.1,
  gamma = 0.01,
  lambda = 0.5,
  max_groups = 5
)

Arguments

data

(data.frame)
data set for model training and uncertainty estimation.

target

(string)
name of the target variable. The target must be a numeric variable.

treatment

(string)
name of the treatment variable. If treatment is a factor, its first level is interpreted as control and its second level as the treatment indicator. If treatment is numeric, zero-one encoding is assumed, with 1 indicating treatment.
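The encoding rules above can be sketched in R as follows. encode_treatment() is a hypothetical helper, not part of the package, and the check that a factor has exactly two levels is an added assumption.

```r
# Hypothetical helper (not part of the package's API): converts a treatment
# column into a logical "treated" indicator following the rules above.
encode_treatment <- function(x) {
  if (is.factor(x)) {
    # Assumption: exactly two levels; first = control, second = treated.
    if (nlevels(x) != 2) stop("a factor treatment must have exactly two levels")
    x == levels(x)[2]
  } else if (is.numeric(x)) {
    # Zero-one encoding: 1 = treated, 0 = control.
    if (!all(x %in% c(0, 1))) stop("a numeric treatment must be coded 0/1")
    x == 1
  } else {
    stop("treatment must be a factor or a numeric 0/1 vector")
  }
}

encode_treatment(factor(c("ctrl", "trt", "ctrl"), levels = c("ctrl", "trt")))
encode_treatment(c(0, 1, 1))
```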

learner

(model_spec)
the learner for training the prediction model. See parsnip::model_spec() for details.

cv_folds

(count)
number of CV+ folds.
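CV+ (the cross-validation variant of the jackknife+) fits one model per fold and combines out-of-fold residuals into a prediction interval. The sketch below illustrates the standard construction for a single new point, assuming lm() as the learner; cv_plus_interval() is a hypothetical helper, not the package's internal implementation. In the standard construction, setting the number of folds equal to the number of rows (as with cv_folds = 500 on 500 simulated rows in the example below) reduces CV+ to the jackknife+.

```r
# Standard CV+ interval for one new point x_new, with lm() as the learner.
# Hypothetical helper for illustration; not the package's internal code.
cv_plus_interval <- function(x, y, x_new, k, alpha) {
  n <- length(y)
  folds <- sample(rep_len(seq_len(k), n))  # random fold assignment
  lo_scores <- numeric(n)
  hi_scores <- numeric(n)
  for (f in seq_len(k)) {
    held_out <- folds == f
    fit <- lm(y ~ x, data = data.frame(x = x[!held_out], y = y[!held_out]))
    # Out-of-fold absolute residuals for the held-out points.
    res <- abs(y[held_out] - predict(fit, newdata = data.frame(x = x[held_out])))
    # Prediction at x_new from the model that did not see this fold.
    mu_new <- predict(fit, newdata = data.frame(x = x_new))
    lo_scores[held_out] <- mu_new - res
    hi_scores[held_out] <- mu_new + res
  }
  lo_rank <- floor(alpha * (n + 1))
  hi_rank <- ceiling((1 - alpha) * (n + 1))
  # If a rank falls outside 1..n, the corresponding bound is infinite.
  lower <- if (lo_rank >= 1) sort(lo_scores)[lo_rank] else -Inf
  upper <- if (hi_rank <= n) sort(hi_scores)[hi_rank] else Inf
  c(lower = lower, upper = upper)
}

set.seed(1)
x <- runif(100)
y <- 2 * x + rnorm(100, sd = 0.1)
cv_plus_interval(x, y, x_new = 0.5, k = 10, alpha = 0.1)
```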

alpha

(proportion)
miscoverage rate; the conformal prediction intervals target a coverage of 1 - alpha.
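To illustrate how the miscoverage rate enters, here is a split-conformal sketch: the interval half-width is the ceiling((1 - alpha) * (n + 1))-th smallest absolute calibration residual. conformal_halfwidth() is a hypothetical helper; the package works with CV+ folds (see cv_folds) rather than a single calibration split, so this only shows the role alpha plays.

```r
# Split-conformal sketch: turn calibration residuals into an interval
# half-width targeting coverage 1 - alpha. Hypothetical helper.
conformal_halfwidth <- function(residuals, alpha) {
  n <- length(residuals)
  k <- ceiling((1 - alpha) * (n + 1))  # conformal rank
  if (k > n) return(Inf)               # too few calibration points for this alpha
  sort(abs(residuals))[k]
}

conformal_halfwidth(residuals = 1:9, alpha = 0.1)  # 9, since ceiling(0.9 * 10) = 9
```

Smaller values of alpha push the rank toward n, so intervals widen; with only 9 residuals, alpha = 0.05 already requires rank 10 and yields an infinite half-width.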

gamma

(proportion)
regularization parameter ensuring that the reduction in impurity, as measured by confident homogeneity, is sufficiently large for a split to be made.
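The exact split rule is not spelled out here. As an illustration only, the sketch below assumes a candidate split is accepted when the size-weighted impurity of the two children is lower than the parent's impurity by at least a fraction gamma of the parent's impurity; accept_split() is a hypothetical helper and the package's criterion may differ.

```r
# Hypothetical split-acceptance rule (an assumption, not the package's code):
# accept a split only if impurity drops by at least gamma * parent impurity.
accept_split <- function(parent_impurity, left_impurity, right_impurity,
                         n_left, n_right, gamma) {
  n <- n_left + n_right
  # Size-weighted impurity of the two children.
  child_impurity <- (n_left * left_impurity + n_right * right_impurity) / n
  (parent_impurity - child_impurity) >= gamma * parent_impurity
}

accept_split(1.00, 0.90, 0.95, n_left = 50, n_right = 50, gamma = 0.01)   # TRUE
accept_split(1.00, 0.995, 0.999, n_left = 50, n_right = 50, gamma = 0.01) # FALSE
```

Under this assumption, a larger gamma demands a bigger improvement per split and therefore tends to produce fewer subgroups.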

lambda

(proportion)
balance parameter quantifying the weight of the average interval width relative to the average absolute deviation.
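A minimal sketch of how lambda could balance the two terms, assuming the impurity of a subgroup is the convex combination lambda * (average interval width) + (1 - lambda) * (average absolute deviation of the estimated effects from the subgroup mean). group_impurity() is hypothetical and the package's exact definition may differ.

```r
# Hypothetical subgroup impurity (assumed form): convex combination of the
# average interval width and the average absolute deviation of the point
# estimates from the subgroup mean.
group_impurity <- function(tau_hat, lower, upper, lambda) {
  avg_width <- mean(upper - lower)
  avg_abs_dev <- mean(abs(tau_hat - mean(tau_hat)))
  lambda * avg_width + (1 - lambda) * avg_abs_dev
}

# 0.5 * 2 + 0.5 * (2/3) = 4/3
group_impurity(tau_hat = c(1, 2, 3), lower = c(0, 1, 2), upper = c(2, 3, 4),
               lambda = 0.5)
```

With lambda = 1 only interval width matters; with lambda = 0 only the spread of the effect estimates does.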

max_groups

(count)
maximum number of subgroups.

Value

The fitted conformal tree. Use summary() to print its subgroups.

Examples

library(tidymodels)
# Synthetic example data:
library(htesim)
set.seed(12)
dgp <- dgp(p = pF_exp_x1_x2,
          m = mF_x1,
          t = tF_div_x1_x2,
          model = "normal",
          xmodel = "unif",
          sd = 1)
sim <- simulate(object = dgp,
               nsim = 500L,
               d = 4L)
# Initialize learner:
linear <- linear_reg() %>%
  set_mode("regression") %>%
  set_engine("lm")
# Detect subgroups:
groups <- r2p_hte(
  data = sim,
  target = "y",
  treatment = "trt",
  learner = linear,
  cv_folds = 500,
  alpha = 0.1,
  gamma = 0.01,
  lambda = 0.5,
  max_groups = 8
)
summary(groups)
#> Conformal tree with 3 subgroups:
#>     n mean width deviation
#> 1 258 0.25  8.47         0
#> 2  79 0.60  6.80         0
#> 3 163 0.66  9.29         0
#> ---
#> Alpha:  0.1 Lambda:  0.5 Gamma:  0.01