Experiments
experiments.Rmd
We compare R2P with R2P+ for subgroup detection in regression data. For each data set and subgroup detection method, we measure heterogeneity across groups with the variance of the subgroup means ($V^{\text{ac}}$, reported as `var_ac`) and homogeneity within groups with the mean subgroup variance ($V^{\text{in}}$, reported as `var_in`). For well-identified subgroups, $V^{\text{ac}}$ should be high and $V^{\text{in}}$ should be low. Further, we are interested in the mean interval width across groups and the total number of subgroups identified. Specifically, we are interested in the performance for small data sets with $n = 100$, as R2P and R2P+ can be expected to perform more similarly the larger a data set is. Therefore, we sample 100 observations from each data set 500 times and aggregate the results. The aggregation also alleviates the randomness induced by the CV splits in R2P (split conformal prediction) and R2P+ (CV+).
Data
For this experiment, we use six regression data sets. Five of them are retrieved from OpenML and identified via their unique ID; a detailed description can be found in the corresponding OpenML entry. The bikes data is loaded from the conftree package.
library("mlr3oml")
# Bikes: 727 rows, 10 features, target "count"; loaded from the conftree
# package rather than from OpenML.
data("bikes", package = "conftree")
# Abalone (OpenML 44956): 8 features, target "rings".
# NOTE(review): the row count previously stated here (6497) is identical to
# the one listed for wines below; abalone is usually reported with 4177
# rows -- confirm against the OpenML entry.
abalone <- as.data.frame(odt(id = 44956)$data)
# Diamonds (OpenML 42225): 53940 rows, 10 features, target "price".
diamonds <- as.data.frame(odt(id = 42225)$data)
# Elevators (OpenML 216): 16599 rows, 18 features, target "Goal".
# The first 16 columns are coerced to numeric.
elevators <- as.data.frame(odt(id = 216)$data)
elevators[seq_len(16)] <- lapply(elevators[seq_len(16)], as.numeric)
# Miami Housing (OpenML 43093): 13932 rows, 15 features, target "SALE_PRC".
# The third column is dropped before conversion; the first 16 remaining
# columns are then coerced to numeric.
miami <- as.data.frame(odt(id = 43093)$data[, -c(3)])
miami[seq_len(16)] <- lapply(miami[seq_len(16)], as.numeric)
# Wines (OpenML 287): 6497 rows, 11 features, target "quality".
wines <- as.data.frame(odt(id = 287)$data)
We rearrange each data set such that the target column is the last column in the data.
# Return `df` with the column named `target` moved to the last position;
# all other columns keep their original order.
target_last <- function(df, target) {
  df[c(setdiff(names(df), target), target)]
}

abalone <- target_last(abalone, "rings")
diamonds <- target_last(diamonds, "price")
elevators <- target_last(elevators, "Goal")
miami <- target_last(miami, "SALE_PRC")
wines <- target_last(wines, "quality")

# Named collection of all data sets used in the experiments
# (bikes is included as loaded, without reordering).
datasets <- list(
  abalone = abalone,
  bikes = bikes,
  diamonds = diamonds,
  elevators = elevators,
  miami = miami,
  wines = wines
)
We standardize the target variables so results are more comparable across data sets.
Experiments
Setup
We use a random forest base learner and set the R2P/R2P+
hyperparameters to $\gamma = 0.01$ and $\lambda = 0.5$;
`max_groups` is left at the default of $10$.
In R2P+, we use `cv_folds = 20`. Further, we set $\alpha = 0.2$,
which means at least 4 observations must be placed in any potential
subgroup. Note that setting $\alpha$
to this relatively large value allows us to study R2P and R2P+ in a
comparable setting, since for $n = 100$
we do not want to limit the ability to detect subgroups by the
theoretical requirements of the conformal guarantees, which ask
for a subgroup of size $\geq 1/\alpha - 1$.
# Fit R2P and R2P+ on one data set.
#
# Arguments:
#   data          Data frame whose LAST column is the target.
#   alpha, gamma, lambda, max_groups
#                 Hyperparameters passed unchanged to both r2p() calls.
#                 The defaults are the values used in the experiments.
#   cv_folds_plus Value of `cv_folds` for the R2P+ fit. The R2P fit
#                 always uses `cv_folds = 1`.
#
# Returns an unnamed list of length two: element 1 is the R2P fit,
# element 2 is the R2P+ fit.
process_data <- function(data,
                         alpha = 0.2,
                         gamma = 0.01,
                         lambda = 0.5,
                         max_groups = 10,
                         cv_folds_plus = 20) {
  ## Learner: a single random forest specification shared by both fits
  forest <- rand_forest() %>%
    set_mode("regression") %>%
    set_engine("ranger")

  # The target is taken to be the last column of `data`
  target <- colnames(data)[ncol(data)]

  # Both methods differ only in the number of CV folds
  fit_with_folds <- function(cv_folds) {
    r2p(
      data = data,
      target = target,
      learner = forest,
      cv_folds = cv_folds,
      alpha = alpha,
      gamma = gamma,
      lambda = lambda,
      max_groups = max_groups
    )
  }

  ## R2P is fitted first, then R2P+, so the order in which random numbers
  ## are consumed stays the same as before.
  g_r2p <- fit_with_folds(1)
  g_r2pp <- fit_with_folds(cv_folds_plus)

  list(g_r2p, g_r2pp)
}
We define functions to run the experiments and aggregate the results. For each of the four measures, we report means and standard deviations for R2P and R2P+, respectively.
# Repeatedly subsample a data set, fit R2P and R2P+ on each subsample and
# collect four measures per fit.
#
# Arguments:
#   data      Data frame passed (row-subsampled) to process_data().
#   rep       Number of repetitions.
#   n_sample  Number of rows drawn without replacement per repetition.
#
# Returns a list with two data frames, "r2p" and "r2p_plus", each with one
# row per repetition and the numeric columns n_groups, avg_width, var_ac
# and var_in. With rep = 0 both data frames have zero rows but still carry
# all four columns.
run_data <- function(data, rep, n_sample) {
  measure_names <- c("n_groups", "avg_width", "var_ac", "var_in")

  # Reduce one fitted object to the four reported measures.
  measures <- function(fit) {
    leaf_ids <- nodeids(fit$tree, terminal = TRUE)
    widths <- as.numeric(
      tree_width(fit$tree, fit$valid_set, fit$info$alpha)
    )
    c(
      fit$info$n_groups,
      # widths are indexed by the terminal node ids, as in the original code
      mean(widths[leaf_ids]),
      fit$info$var_ac,
      fit$info$var_in
    )
  }

  # Preallocate one row per repetition instead of growing vectors in the loop
  res_r2p <- matrix(
    NA_real_,
    nrow = rep,
    ncol = length(measure_names),
    dimnames = list(NULL, measure_names)
  )
  res_r2pp <- res_r2p

  # Loop through repetitions; sampling happens before fitting in every
  # iteration so the sequence of random draws is unchanged.
  for (i in seq_len(rep)) {
    sample_s <- data[sample(nrow(data), n_sample), ]
    pis <- process_data(sample_s)
    res_r2p[i, ] <- measures(pis[[1]])
    res_r2pp[i, ] <- measures(pis[[2]])
  }

  # Return both result tables as data frames
  list("r2p" = as.data.frame(res_r2p), "r2p_plus" = as.data.frame(res_r2pp))
}
# Aggregate the per-repetition results of each method.
#
# Arguments:
#   results  List of numeric tables (one per method), with one row per
#            repetition and one column per measure.
#
# Returns a list of the same length and names as `results`. Each element
# is a data frame with one row per measure and the columns `mean` and `sd`,
# both computed with missing values removed.
agg_data <- function(results) {
  lapply(results, function(res) {
    # Iterate over columns directly instead of apply(), which would first
    # coerce the whole data frame to a matrix.
    cols <- as.data.frame(res)
    data.frame(
      mean = colMeans(res, na.rm = TRUE),
      sd = vapply(cols, sd, numeric(1), na.rm = TRUE)
    )
  })
}
Results
# Print the aggregated results: one entry per data set, each holding a
# mean/sd table for R2P ("r2p") and R2P+ ("r2p_plus").
# NOTE(review): final_res is not created in the code shown here --
# presumably it is built by applying run_data() and agg_data() to each
# element of `datasets`; confirm.
final_res
#> $abalone
#> $abalone$r2p
#> mean sd
#> n_groups 2.9320000 1.1841370
#> avg_width 1.7192092 0.3758183
#> var_ac 0.3246851 0.2114961
#> var_in 0.7200171 0.2330315
#>
#> $abalone$r2p_plus
#> mean sd
#> n_groups 4.1040000 1.7073412
#> avg_width 1.5541796 0.2754160
#> var_ac 0.3935642 0.2171481
#> var_in 0.6212842 0.2048562
#>
#>
#> $bikes
#> $bikes$r2p
#> mean sd
#> n_groups 3.8220000 1.5293770
#> avg_width 1.1479217 0.2350975
#> var_ac 0.4396868 0.2385617
#> var_in 0.5108834 0.2329907
#>
#> $bikes$r2p_plus
#> mean sd
#> n_groups 6.5420000 1.9452001
#> avg_width 0.8839908 0.1341617
#> var_ac 0.6397187 0.1880607
#> var_in 0.2838799 0.1090213
#>
#>
#> $diamonds
#> $diamonds$r2p
#> mean sd
#> n_groups 5.0740000 1.4030185
#> avg_width 0.6886794 0.2402770
#> var_ac 0.8467990 0.4450025
#> var_in 0.2719968 0.1242952
#>
#> $diamonds$r2p_plus
#> mean sd
#> n_groups 7.9900000 1.61961301
#> avg_width 0.6229672 0.16552796
#> var_ac 0.9429163 0.41577069
#> var_in 0.1881956 0.07277692
#>
#>
#> $elevators
#> $elevators$r2p
#> mean sd
#> n_groups 3.1820000 1.1364237
#> avg_width 1.5252150 0.4168255
#> var_ac 0.2075417 0.3270168
#> var_in 0.8487943 0.3133283
#>
#> $elevators$r2p_plus
#> mean sd
#> n_groups 4.4960000 1.7146031
#> avg_width 1.4527899 0.3400333
#> var_ac 0.3152598 0.4022665
#> var_in 0.7816624 0.3358434
#>
#>
#> $miami
#> $miami$r2p
#> mean sd
#> n_groups 3.2000000 1.1742051
#> avg_width 0.9518659 0.2809258
#> var_ac 0.1777317 0.2689256
#> var_in 0.8893836 0.4890409
#>
#> $miami$r2p_plus
#> mean sd
#> n_groups 4.8280000 1.8387981
#> avg_width 0.8086003 0.1850674
#> var_ac 0.2433055 0.2346503
#> var_in 0.7231350 0.3757193
#>
#>
#> $wines
#> $wines$r2p
#> mean sd
#> n_groups 3.1240000 1.2774536
#> avg_width 2.0571390 0.3907371
#> var_ac 0.1037108 0.1071144
#> var_in 0.8915611 0.2061902
#>
#> $wines$r2p_plus
#> mean sd
#> n_groups 4.0920000 1.8639970
#> avg_width 1.9726951 0.2942497
#> var_ac 0.1380350 0.1251651
#> var_in 0.8452096 0.2142913