Skip to contents

We compare R2P with R2P+ for subgroup detection in regression data. For each data set and subgroup detection method, we measure heterogeneity across groups with the variance of the subgroup means ($Var_{ac}$) and homogeneity within groups with the mean subgroup variance ($Var_{in}$). For well-identified subgroups, $Var_{ac}$ should be high and $Var_{in}$ should be low. Further, we are interested in the mean interval width across groups and the total number of subgroups identified. Specifically, we are interested in the performance for small data sets with $n = 100$, as R2P and R2P+ can be expected to perform more similarly the larger a data set is. Therefore, we sample 100 observations from each data set 500 times and aggregate the results. The aggregation also alleviates the randomness induced by the CV splits in R2P (split conformal prediction) and R2P+ (CV+).

Data

For this experiment, we use 6 regression data sets. The bikes data ships with the conftree package; the remaining data sets are retrieved from OpenML and identified via their unique ID. A detailed description can be found on the corresponding OpenML website entry.

library("mlr3oml")

# Bikes ships with the conftree package (n = 727, k = 10, target = "count").
data(bikes, package = "conftree")

# The remaining data sets are fetched with odt() using their OpenML data ID.

# Abalone, ID 44956 (k = 8, target = "rings").
# NOTE(review): the original comment listed n = 6497, which is identical to
# the Wines entry below and looks like a copy-paste slip -- confirm on OpenML.
abalone <- as.data.frame(odt(id = 44956)$data)

# Diamonds, ID 42225 (n = 53940, k = 10, target = "price").
diamonds <- as.data.frame(odt(id = 42225)$data)

# Elevators, ID 216 (n = 16599, k = 18, target = "Goal").
# The first 16 columns are coerced to numeric.
elevators <- as.data.frame(odt(id = 216)$data)
elevators[seq_len(16)] <- lapply(elevators[seq_len(16)], as.numeric)

# Miami Housing, ID 43093 (n = 13932, k = 15, target = "SALE_PRC").
# The third column of the download is dropped before the first 16 remaining
# columns are coerced to numeric.
miami <- as.data.frame(odt(id = 43093)$data[, -c(3)])
miami[seq_len(16)] <- lapply(miami[seq_len(16)], as.numeric)

# Wines, ID 287 (n = 6497, k = 11, target = "quality").
wines <- as.data.frame(odt(id = 287)$data)

We rearrange each data set such that the target column is the last column in the data.

# Move each target column to the last position; all other columns keep their
# relative order. (bikes is used as loaded and is not reordered here.)
abalone <- abalone[c(setdiff(colnames(abalone), "rings"), "rings")]
diamonds <- diamonds[c(setdiff(colnames(diamonds), "price"), "price")]
elevators <- elevators[c(setdiff(colnames(elevators), "Goal"), "Goal")]
miami <- miami[c(setdiff(colnames(miami), "SALE_PRC"), "SALE_PRC")]
wines <- wines[c(setdiff(colnames(wines), "quality"), "quality")]

# Collect all data sets in one named list, keyed by data set name.
datasets <- list(
  abalone = abalone,
  bikes = bikes,
  diamonds = diamonds,
  elevators = elevators,
  miami = miami,
  wines = wines
)

We standardize the target variables so results are more comparable across data sets.

# Named list of all data sets (target is the last column of each data frame).
datasets <- list(
  abalone = abalone,
  bikes = bikes,
  diamonds = diamonds,
  elevators = elevators,
  miami = miami,
  wines = wines
)

# Centre and scale the last column (the target) of every data set using its
# mean and standard deviation, ignoring missing values.
datasets <- lapply(datasets, function(df) {
  target_idx <- ncol(df)
  target <- df[[target_idx]]
  centre <- mean(target, na.rm = TRUE)
  spread <- sd(target, na.rm = TRUE)
  df[[target_idx]] <- (target - centre) / spread
  df
})

Experiments

Setup

We use a random forest base learner and set the R2P/R2P+ hyperparameters to $\lambda = 0.5$ and $\gamma = 0.01$; max_groups is left at the default of $10$. In R2P+, we use cv_folds = 20. Further, we set $\alpha = 0.2$, which means at least 4 observations must be placed in any potential subgroup. Note that setting $\alpha$ to this relatively large value allows us to study R2P and R2P+ in a comparable setting, since for $n = 100$ we do not want to limit the ability to detect subgroups by the theoretical requirements of the conformal guarantees, which require $\alpha \geq \frac{1}{n_l + 1}$ for a subgroup of size $n_l$.

#' Fit R2P and R2P+ on one data set.
#'
#' Both fits use the same random forest learner specification (ranger engine,
#' regression mode) and the same hyperparameters (alpha = 0.2, gamma = 0.01,
#' lambda = 0.5, max_groups = 10); they differ only in `cv_folds`.
#'
#' @param data Data frame whose last column is the target.
#' @return An unnamed list of length two: the fit with `cv_folds = 1` (R2P)
#'   followed by the fit with `cv_folds = 20` (R2P+).
process_data <- function(data) {
  target_name <- colnames(data)[ncol(data)]

  ## Learner (shared by both fits)
  forest <- set_engine(set_mode(rand_forest(), "regression"), "ranger")

  # Single place for the r2p() call; only the number of CV folds varies.
  fit_with_folds <- function(folds) {
    r2p(
      data = data,
      target = target_name,
      learner = forest,
      cv_folds = folds,
      alpha = 0.2,
      gamma = 0.01,
      lambda = 0.5,
      max_groups = 10
    )
  }

  ## R2P first, then R2P+, in the same order as before.
  g_r2p <- fit_with_folds(1)
  g_r2pp <- fit_with_folds(20)

  list(g_r2p, g_r2pp)
}

We define functions to run the experiments and aggregate the results. For each of the four measures, we report means and standard deviations for R2P and R2P+, respectively.

#' Repeatedly subsample a data set and record R2P / R2P+ subgroup metrics.
#'
#' In every repetition, `n_sample` rows are drawn from `data` with `sample()`,
#' both methods are fitted via `process_data()`, and four metrics are stored
#' per method: number of groups, mean interval width over the terminal nodes,
#' `var_ac` and `var_in`.
#'
#' @param data Data frame handed (after subsampling) to `process_data()`.
#' @param rep Number of repetitions.
#' @param n_sample Number of rows drawn in each repetition.
#' @return A list with elements "r2p" and "r2p_plus". Each is a data frame
#'   with one row per repetition and the numeric columns `n_groups`,
#'   `avg_width`, `var_ac` and `var_in`. With `rep = 0` the data frames have
#'   zero rows but still carry all four columns.
run_data <- function(data, rep, n_sample) {
  metrics <- c("n_groups", "avg_width", "var_ac", "var_in")

  # Preallocate one row per repetition instead of growing vectors from NULL;
  # this also guarantees the four columns exist when rep is 0.
  res_r2p <- matrix(
    NA_real_,
    nrow = rep,
    ncol = length(metrics),
    dimnames = list(NULL, metrics)
  )
  res_r2pp <- res_r2p

  # Extract the four metrics (in the order of `metrics`) from a single fit.
  summarise_fit <- function(fit) {
    leaf_ids <- nodeids(fit$tree, terminal = TRUE)
    widths <- as.numeric(
      tree_width(fit$tree, fit$valid_set, fit$info$alpha)
    )
    c(
      fit$info$n_groups,
      # Widths are indexed by node id, so only terminal nodes are averaged.
      mean(widths[leaf_ids]),
      fit$info$var_ac,
      fit$info$var_in
    )
  }

  # Loop through repetitions
  for (i in seq_len(rep)) {
    sample_s <- data[sample(nrow(data), n_sample), ]
    pis <- process_data(sample_s)
    res_r2p[i, ] <- summarise_fit(pis[[1]])
    res_r2pp[i, ] <- summarise_fit(pis[[2]])
  }

  # Return both result tables as data frames
  list(
    "r2p" = as.data.frame(res_r2p),
    "r2p_plus" = as.data.frame(res_r2pp)
  )
}

#' Aggregate per-repetition metrics into means and standard deviations.
#'
#' @param results List of numeric data frames (one row per repetition, one
#'   column per metric), e.g. the return value of `run_data()`.
#' @return A list of the same length and names as `results`; each element is
#'   a data frame with one row per metric and the columns `mean` and `sd`,
#'   both computed with missing values removed.
agg_data <- function(results) {
  summarise_one <- function(metrics) {
    col_means <- colMeans(metrics, na.rm = TRUE)
    col_sds <- vapply(metrics, sd, numeric(1), na.rm = TRUE)
    data.frame(mean = col_means, sd = col_sds)
  }
  lapply(results, summarise_one)
}

Run

# Fix the RNG seed, then for every data set run 500 repetitions on samples of
# 100 observations and aggregate the per-repetition metrics.
set.seed(123)
final_res <- lapply(
  datasets,
  function(dataset) {
    raw_results <- run_data(dataset, rep = 500, n_sample = 100)
    agg_data(raw_results)
  }
)

Results

# Print the aggregated results: per data set and per method ("r2p",
# "r2p_plus"), the mean and standard deviation of each metric.
final_res
#> $abalone
#> $abalone$r2p
#>                mean        sd
#> n_groups  2.9320000 1.1841370
#> avg_width 1.7192092 0.3758183
#> var_ac    0.3246851 0.2114961
#> var_in    0.7200171 0.2330315
#> 
#> $abalone$r2p_plus
#>                mean        sd
#> n_groups  4.1040000 1.7073412
#> avg_width 1.5541796 0.2754160
#> var_ac    0.3935642 0.2171481
#> var_in    0.6212842 0.2048562
#> 
#> 
#> $bikes
#> $bikes$r2p
#>                mean        sd
#> n_groups  3.8220000 1.5293770
#> avg_width 1.1479217 0.2350975
#> var_ac    0.4396868 0.2385617
#> var_in    0.5108834 0.2329907
#> 
#> $bikes$r2p_plus
#>                mean        sd
#> n_groups  6.5420000 1.9452001
#> avg_width 0.8839908 0.1341617
#> var_ac    0.6397187 0.1880607
#> var_in    0.2838799 0.1090213
#> 
#> 
#> $diamonds
#> $diamonds$r2p
#>                mean        sd
#> n_groups  5.0740000 1.4030185
#> avg_width 0.6886794 0.2402770
#> var_ac    0.8467990 0.4450025
#> var_in    0.2719968 0.1242952
#> 
#> $diamonds$r2p_plus
#>                mean         sd
#> n_groups  7.9900000 1.61961301
#> avg_width 0.6229672 0.16552796
#> var_ac    0.9429163 0.41577069
#> var_in    0.1881956 0.07277692
#> 
#> 
#> $elevators
#> $elevators$r2p
#>                mean        sd
#> n_groups  3.1820000 1.1364237
#> avg_width 1.5252150 0.4168255
#> var_ac    0.2075417 0.3270168
#> var_in    0.8487943 0.3133283
#> 
#> $elevators$r2p_plus
#>                mean        sd
#> n_groups  4.4960000 1.7146031
#> avg_width 1.4527899 0.3400333
#> var_ac    0.3152598 0.4022665
#> var_in    0.7816624 0.3358434
#> 
#> 
#> $miami
#> $miami$r2p
#>                mean        sd
#> n_groups  3.2000000 1.1742051
#> avg_width 0.9518659 0.2809258
#> var_ac    0.1777317 0.2689256
#> var_in    0.8893836 0.4890409
#> 
#> $miami$r2p_plus
#>                mean        sd
#> n_groups  4.8280000 1.8387981
#> avg_width 0.8086003 0.1850674
#> var_ac    0.2433055 0.2346503
#> var_in    0.7231350 0.3757193
#> 
#> 
#> $wines
#> $wines$r2p
#>                mean        sd
#> n_groups  3.1240000 1.2774536
#> avg_width 2.0571390 0.3907371
#> var_ac    0.1037108 0.1071144
#> var_in    0.8915611 0.2061902
#> 
#> $wines$r2p_plus
#>                mean        sd
#> n_groups  4.0920000 1.8639970
#> avg_width 1.9726951 0.2942497
#> var_ac    0.1380350 0.1251651
#> var_in    0.8452096 0.2142913