The goal of this vignette is to explain how to quantify the extent to which it is possible to train on one data subset, and predict on another data subset. This kind of problem occurs frequently in many different problem domains, whenever the rows of a data set can be grouped into subsets (such as the different people in the simulations below), and we want to know how well a model trained on some subsets can predict on the others.

The ideas are similar to my previous blog posts about how to do this in Python and R. Below we explain how to use mlr3resampling for this purpose, on simulated regression and classification problems. To apply this method to real data, the most important sections to read below are the ones named “Benchmark: computing test error,” which show how to create these cross-validation experiments using mlr3 code.

Simulated regression problems

We begin by generating some data which can be used with regression algorithms. Assume there is a data set with some rows from one person, some rows from another,

N <- 300
library(data.table)
set.seed(1)
abs.x <- 2
reg.dt <- data.table(
  x=runif(N, -abs.x, abs.x),
  person=rep(1:2, each=0.5*N))
reg.pattern.list <- list(
  easy=function(x, person)x^2,
  impossible=function(x, person)(x^2+person*3)*(-1)^person)
reg.task.list <- list()
for(task_id in names(reg.pattern.list)){
  f <- reg.pattern.list[[task_id]]
  yname <- paste0("y_",task_id)
  reg.dt[, (yname) := f(x,person)+rnorm(N)][]
  task.dt <- reg.dt[, c("x","person",yname), with=FALSE]
  reg.task <- mlr3::TaskRegr$new(
    task_id, task.dt, target=yname)
  reg.task$col_roles$subset <- "person"
  reg.task$col_roles$stratum <- "person"
  reg.task$col_roles$feature <- "x"
  reg.task.list[[task_id]] <- reg.task
}
reg.dt
#>               x person      y_easy y_impossible
#>           <num>  <int>       <num>        <num>
#>   1: -0.9379653      1  1.32996609    -2.918082
#>   2: -0.5115044      1  0.24307692    -3.866062
#>   3:  0.2914135      1 -0.23314657    -3.837799
#>   4:  1.6328312      1  1.73677545    -7.221749
#>   5: -1.1932723      1 -0.06356159    -5.877792
#>  ---                                           
#> 296:  0.7257701      2 -2.48130642     5.180948
#> 297: -1.6033236      2  1.20453459     9.604312
#> 298: -1.5243898      2  1.89966190     7.511988
#> 299: -1.7982414      2  3.47047566    11.035397
#> 300:  1.7170157      2  0.60541972    10.719685

The table above shows some simulated data for two regression problems:

  • for the easy pattern, the output y_easy is a function of x only (x squared plus noise), so it is the same for both people.
  • for the impossible pattern, the output y_impossible depends on both x and person, so the relationship between x and y is different for each person.
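
The for loop above also assigned column roles to each task: x is the feature, person is both the subset and the stratum, and the simulated y column is the prediction target. A quick way to verify the roles of one task (a sketch using the standard mlr3 col_roles list) is shown below.

reg.task.list$easy$col_roles[c("target", "feature", "subset", "stratum")]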

Static visualization of simulated data

First we reshape the data using the code below,

(reg.tall <- nc::capture_melt_single(
  reg.dt,
  task_id="easy|impossible",
  value.name="y"))
#>               x person    task_id           y
#>           <num>  <int>     <char>       <num>
#>   1: -0.9379653      1       easy  1.32996609
#>   2: -0.5115044      1       easy  0.24307692
#>   3:  0.2914135      1       easy -0.23314657
#>   4:  1.6328312      1       easy  1.73677545
#>   5: -1.1932723      1       easy -0.06356159
#>  ---                                         
#> 596:  0.7257701      2 impossible  5.18094849
#> 597: -1.6033236      2 impossible  9.60431191
#> 598: -1.5243898      2 impossible  7.51198770
#> 599: -1.7982414      2 impossible 11.03539747
#> 600:  1.7170157      2 impossible 10.71968480

The table above is a more convenient form for the visualization which we create using the code below,

if(require(animint2)){
  ggplot()+
    geom_point(aes(
      x, y),
      data=reg.tall)+
    facet_grid(
      task_id ~ person,
      labeller=label_both,
      space="free",
      scales="free")+
    scale_y_continuous(
      breaks=seq(-100, 100, by=2))
}
#> Loading required package: animint2

In the simulated data above, we can see that

  • for the easy pattern, the relationship between x and y is the same for both people, so it should be possible/easy to train on one person, and accurately predict on another.
  • for the impossible pattern, the relationship is different for each person, so it should not be possible to train on one person, and accurately predict on another (a quick numeric check of both claims follows below).
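
As a quick numeric check of the claims above, the code below computes the mean of y for each pattern and person (a sketch; the exact values depend on the random seed set above). For the easy pattern the two means should be similar, whereas for the impossible pattern they should be very different.

reg.tall[, .(mean.y=mean(y), sd.y=sd(y)), by=.(task_id, person)]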

Benchmark: computing test error

In the code below, we define the resampling method: same/other K-fold cross-validation, which by default uses K=3 folds.

(reg_same_other <- mlr3resampling::ResamplingSameOtherCV$new())
#> <ResamplingSameOtherCV> : Same versus Other Cross-Validation
#> * Iterations:
#> * Instantiated: FALSE
#> * Parameters:
#> List of 1
#>  $ folds: int 3
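
The default of 3 folds can be changed via the usual mlr3 parameter interface; the sketch below (assuming the standard param_set$values access) creates a second resampling object with 5 folds, without modifying the one used in the rest of this vignette.

reg_same_other_5folds <- mlr3resampling::ResamplingSameOtherCV$new()
reg_same_other_5folds$param_set$values$folds <- 5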

In the code below, we define two learners to compare,

(reg.learner.list <- list(
  if(requireNamespace("rpart"))mlr3::LearnerRegrRpart$new(),
  mlr3::LearnerRegrFeatureless$new()))
#> Loading required namespace: rpart
#> [[1]]
#> <LearnerRegrRpart:regr.rpart>: Regression Tree
#> * Model: -
#> * Parameters: xval=0
#> * Packages: mlr3, rpart
#> * Predict Types:  [response]
#> * Feature Types: logical, integer, numeric, factor, ordered
#> * Properties: importance, missings, selected_features, weights
#> 
#> [[2]]
#> <LearnerRegrFeatureless:regr.featureless>: Featureless Regression Learner
#> * Model: -
#> * Parameters: robust=FALSE
#> * Packages: mlr3, stats
#> * Predict Types:  [response], se
#> * Feature Types: logical, integer, numeric, character, factor, ordered,
#>   POSIXct
#> * Properties: featureless, importance, missings, selected_features
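
When adapting this comparison to real data, other learners can be added to the list, and hyper-parameters can be changed. The sketch below assumes the usual rpart parameter names (minsplit, cp) are present in the learner's param_set, and shows how a modified regression tree could be defined.

if(requireNamespace("rpart")){
  reg.rpart.tuned <- mlr3::LearnerRegrRpart$new()
  reg.rpart.tuned$param_set$values$minsplit <- 10 # attempt splits in smaller nodes than the default of 20.
  reg.rpart.tuned$param_set$values$cp <- 0.005    # lower complexity penalty, allowing more splits.
  reg.rpart.tuned
}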

In the code below, we define the benchmark grid, which is all combinations of tasks (easy and impossible), learners (rpart and featureless), and the one resampling method.

(reg.bench.grid <- mlr3::benchmark_grid(
  reg.task.list,
  reg.learner.list,
  reg_same_other))
#>          task          learner    resampling
#>        <char>           <char>        <char>
#> 1:       easy       regr.rpart same_other_cv
#> 2:       easy regr.featureless same_other_cv
#> 3: impossible       regr.rpart same_other_cv
#> 4: impossible regr.featureless same_other_cv

In the code below, we execute the benchmark experiment (the if(FALSE) block shows how to optionally run it in parallel, by declaring a multisession future plan; it is disabled here for CRAN).

if(FALSE){#for CRAN.
  if(require(future))plan("multisession")
}
if(require(lgr))get_logger("mlr3")$set_threshold("warn")
#> Loading required package: lgr
(reg.bench.result <- mlr3::benchmark(
  reg.bench.grid, store_models = TRUE))
#> <BenchmarkResult> of 72 rows with 4 resampling runs
#>  nr    task_id       learner_id resampling_id iters warnings errors
#>   1       easy       regr.rpart same_other_cv    18        0      0
#>   2       easy regr.featureless same_other_cv    18        0      0
#>   3 impossible       regr.rpart same_other_cv    18        0      0
#>   4 impossible regr.featureless same_other_cv    18        0      0

The code below computes the test error for each split,

reg.bench.score <- mlr3resampling::score(reg.bench.result)
reg.bench.score[1]
#>    train.subsets test.fold test.subset person iteration                  test
#>           <char>     <int>       <int>  <int>     <int>                <list>
#> 1:           all         1           1      1         1  1, 3, 5, 6,12,13,...
#>                    train                                uhash    nr
#>                   <list>                               <char> <int>
#> 1:  4, 7, 9,10,18,20,... 9d0598d4-4e81-4885-9be4-c6e919c8602e     1
#>               task task_id                       learner learner_id
#>             <list>  <char>                        <list>     <char>
#> 1: <TaskRegr:easy>    easy <LearnerRegrRpart:regr.rpart> regr.rpart
#>                 resampling resampling_id       prediction regr.mse algorithm
#>                     <list>        <char>           <list>    <num>    <char>
#> 1: <ResamplingSameOtherCV> same_other_cv <PredictionRegr> 1.638015     rpart
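
Before plotting, it can be useful to summarize these per-split scores. The sketch below (using the column names shown in the score table above; exact values depend on the random seed) computes the mean test error over splits, for each task, algorithm, and set of train subsets.

reg.bench.score[, .(
  mean.regr.mse=mean(regr.mse),
  n.splits=.N
), by=.(task_id, algorithm, train.subsets)]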

The code below visualizes the resulting test error values.

if(require(animint2)){
  ggplot()+
    scale_x_log10()+
    geom_point(aes(
      regr.mse, train.subsets, color=algorithm),
      shape=1,
      data=reg.bench.score)+
    facet_grid(
      task_id ~ person,
      labeller=label_both,
      scales="free")
}

It is clear from the plot above that

  • for the easy task, training on same is just as good as all or other subsets. rpart has much lower test error than featureless, in all three train subsets.
  • for the impossible task, the least test error is using rpart with same train subsets; featureless with same train subsets is next best; training on all is substantially worse (for both featureless and rpart); training on other is even worse (patterns in the two people are completely different).
  • in a real data task, training on other will most likely not be quite as bad as in the impossible task above, but also not as good as in the easy task.

Interactive visualization of data, test error, and splits

The code below can be used to create an interactive data visualization which allows exploring how different functions are learned during different splits.

inst <- reg.bench.score$resampling[[1]]$instance
rect.expand <- 0.2
grid.dt <- data.table(x=seq(-abs.x, abs.x, l=101), y=0)
grid.task <- mlr3::TaskRegr$new("grid", grid.dt, target="y")
pred.dt.list <- list()
point.dt.list <- list()
for(score.i in 1:nrow(reg.bench.score)){
  reg.bench.row <- reg.bench.score[score.i]
  task.dt <- data.table(
    reg.bench.row$task[[1]]$data(),
    reg.bench.row$resampling[[1]]$instance$id.dt)
  names(task.dt)[1] <- "y"
  set.ids <- data.table(
    set.name=c("test","train")
  )[
  , data.table(row_id=reg.bench.row[[set.name]][[1]])
  , by=set.name]
  i.points <- set.ids[
    task.dt, on="row_id"
  ][
    is.na(set.name), set.name := "unused"
  ]
  point.dt.list[[score.i]] <- data.table(
    reg.bench.row[, .(task_id, iteration)],
    i.points)
  i.learner <- reg.bench.row$learner[[1]]
  pred.dt.list[[score.i]] <- data.table(
    reg.bench.row[, .(
      task_id, iteration, algorithm
    )],
    as.data.table(
      i.learner$predict(grid.task)
    )[, .(x=grid.dt$x, y=response)]
  )
}
(pred.dt <- rbindlist(pred.dt.list))
#>          task_id iteration   algorithm     x        y
#>           <char>     <int>      <char> <num>    <num>
#>    1:       easy         1       rpart -2.00 3.557968
#>    2:       easy         1       rpart -1.96 3.557968
#>    3:       easy         1       rpart -1.92 3.557968
#>    4:       easy         1       rpart -1.88 3.557968
#>    5:       easy         1       rpart -1.84 3.557968
#>   ---                                                
#> 7268: impossible        18 featureless  1.84 7.204232
#> 7269: impossible        18 featureless  1.88 7.204232
#> 7270: impossible        18 featureless  1.92 7.204232
#> 7271: impossible        18 featureless  1.96 7.204232
#> 7272: impossible        18 featureless  2.00 7.204232
(point.dt <- rbindlist(point.dt.list))
#>           task_id iteration set.name row_id           y          x  fold person
#>            <char>     <int>   <char>  <int>       <num>      <num> <int>  <int>
#>     1:       easy         1     test      1  1.32996609 -0.9379653     1      1
#>     2:       easy         1    train      2  0.24307692 -0.5115044     3      1
#>     3:       easy         1     test      3 -0.23314657  0.2914135     1      1
#>     4:       easy         1    train      4  1.73677545  1.6328312     2      1
#>     5:       easy         1     test      5 -0.06356159 -1.1932723     1      1
#>    ---                                                                         
#> 21596: impossible        18    train    296  5.18094849  0.7257701     1      2
#> 21597: impossible        18    train    297  9.60431191 -1.6033236     1      2
#> 21598: impossible        18     test    298  7.51198770 -1.5243898     3      2
#> 21599: impossible        18    train    299 11.03539747 -1.7982414     1      2
#> 21600: impossible        18     test    300 10.71968480  1.7170157     3      2
#>        subset display_row
#>         <int>       <int>
#>     1:      1           1
#>     2:      1         101
#>     3:      1           2
#>     4:      1          51
#>     5:      1           3
#>    ---                   
#> 21596:      2         198
#> 21597:      2         199
#> 21598:      2         299
#> 21599:      2         200
#> 21600:      2         300
set.colors <- c(
  train="#1B9E77",
  test="#D95F02",
  unused="white")
algo.colors <- c(
  featureless="blue",
  rpart="red")
make_person_subset <- function(DT){
  DT[, "person/subset" := person]
}
make_person_subset(point.dt)
make_person_subset(reg.bench.score)

if(require(animint2)){
  viz <- animint(
    title="Train/predict on subsets, regression",
    pred=ggplot()+
      ggtitle("Predictions for selected train/test split")+
      theme_animint(height=400)+
      scale_fill_manual(values=set.colors)+
      geom_point(aes(
        x, y, fill=set.name),
        showSelected="iteration",
        size=3,
        shape=21,
        data=point.dt)+
      scale_color_manual(values=algo.colors)+
      geom_line(aes(
        x, y, color=algorithm, subset=paste(algorithm, iteration)),
        showSelected="iteration",
        data=pred.dt)+
      facet_grid(
        task_id ~ `person/subset`,
        labeller=label_both,
        space="free",
        scales="free")+
      scale_y_continuous(
        breaks=seq(-100, 100, by=2)),
    err=ggplot()+
      ggtitle("Test error for each split")+
      theme_animint(height=400)+
      scale_y_log10(
        "Mean squared error on test set")+
      scale_fill_manual(values=algo.colors)+
      scale_x_discrete(
        "People/subsets in train set")+
      geom_point(aes(
        train.subsets, regr.mse, fill=algorithm),
        shape=1,
        size=5,
        stroke=2,
        color="black",
        color_off=NA,
        clickSelects="iteration",
        data=reg.bench.score)+
      facet_grid(
        task_id ~ `person/subset`,
        labeller=label_both,
        scales="free"),
    diagram=ggplot()+
      ggtitle("Select train/test split")+
      theme_bw()+
      theme_animint(height=300)+
      facet_grid(
        . ~ train.subsets,
        scales="free",
        space="free")+
      scale_size_manual(values=c(subset=3, fold=1))+
      scale_color_manual(values=c(subset="orange", fold="grey50"))+
      geom_rect(aes(
        xmin=-Inf, xmax=Inf,
        color=rows,
        size=rows,
        ymin=display_row, ymax=display_end),
        fill=NA,
        data=inst$viz.rect.dt)+
      scale_fill_manual(values=set.colors)+
      geom_rect(aes(
        xmin=iteration-rect.expand, ymin=display_row,
        xmax=iteration+rect.expand, ymax=display_end,
        fill=set.name),
        clickSelects="iteration",
        data=inst$viz.set.dt)+
      geom_text(aes(
        ifelse(rows=="subset", Inf, -Inf),
        (display_row+display_end)/2,
        hjust=ifelse(rows=="subset", 1, 0),
        label=paste0(rows, "=", ifelse(rows=="subset", subset, fold))),
        data=data.table(train.name="same", inst$viz.rect.dt))+
      scale_x_continuous(
        "Split number / cross-validation iteration")+
      scale_y_continuous(
        "Row number"),
    source="https://github.com/tdhock/mlr3resampling/blob/main/vignettes/ResamplingSameOtherCV.Rmd")
  viz
}

if(FALSE){
  animint2pages(viz, "2023-12-13-train-predict-subsets-regression")
}

If you are viewing this in an installed package or on CRAN, then there will be no data viz on this page, but you can view it on: https://tdhock.github.io/2023-12-13-train-predict-subsets-regression/

Simulated classification problems

The previous section investigated a simulated regression problem, whereas in this section we simulate a binary classification problem. Assume there is a data set with some rows from one person, some rows from another,

N <- 200
library(data.table)
(full.dt <- data.table(
  label=factor(rep(c("spam","not spam"), l=N)),
  person=rep(1:2, each=0.5*N)
)[, signal := ifelse(label=="not spam", 0, 3)][])
#>         label person signal
#>        <fctr>  <int>  <num>
#>   1:     spam      1      3
#>   2: not spam      1      0
#>   3:     spam      1      3
#>   4: not spam      1      0
#>   5:     spam      1      3
#>  ---                       
#> 196: not spam      2      0
#> 197:     spam      2      3
#> 198: not spam      2      0
#> 199:     spam      2      3
#> 200: not spam      2      0

Above, each row has a person ID, either 1 or 2. We can imagine a spam filtering system that has training data for multiple people (here just two). Each row in the table above represents a message which has been labeled as spam or not by one of the two people. Can we train on one person, and accurately predict on the other person? To do that we will need some features, which we generate/simulate below:

set.seed(1)
n.people <- length(unique(full.dt$person))
for(person.i in 1:n.people){
  use.signal.vec <- list(
    easy=rep(if(person.i==1)TRUE else FALSE, N),
    impossible=full.dt$person==person.i)
  for(task_id in names(use.signal.vec)){
    use.signal <- use.signal.vec[[task_id]]
    full.dt[
    , paste0("x",person.i,"_",task_id) := ifelse(
      use.signal, signal, 0
    )+rnorm(N)][]
  }
}
full.dt
#>         label person signal    x1_easy x1_impossible    x2_easy x2_impossible
#>        <fctr>  <int>  <num>      <num>         <num>      <num>         <num>
#>   1:     spam      1      3  2.3735462     3.4094018  1.0744410    -0.3410670
#>   2: not spam      1      0  0.1836433     1.6888733  1.8956548     1.5024245
#>   3:     spam      1      3  2.1643714     4.5865884 -0.6029973     0.5283077
#>   4: not spam      1      0  1.5952808    -0.3309078 -0.3908678     0.5421914
#>   5:     spam      1      3  3.3295078     0.7147645 -0.4162220    -0.1366734
#>  ---                                                                         
#> 196: not spam      2      0 -1.0479844    -0.9243128  0.7682782    -1.0293917
#> 197:     spam      2      3  4.4411577     1.5929138 -0.8161606     2.9890743
#> 198: not spam      2      0 -1.0158475     0.0450106 -0.4361069    -1.2249912
#> 199:     spam      2      3  3.4119747    -0.7151284  0.9047050     0.4038886
#> 200: not spam      2      0 -0.3810761     0.8652231 -0.7630863     1.1691226

In the table above, there are two sets of two features (a quick numeric check follows the list below):

  • for the easy task, the x1_easy feature is related to the label in the same way for both people, and x2_easy is just noise.
  • for the impossible task, x1_impossible is related to the label only for person 1, and x2_impossible is related to the label only for person 2.
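
The sketch below computes the mean of each feature, by person and label (the exact values depend on the random seed set above). The spam rows should have a clearly larger mean than the not spam rows only in the columns that contain a signal for that person.

full.dt[, lapply(.SD, mean), by=.(person, label), .SDcols=patterns("^x")]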

Static visualization of simulated data

Below we reshape the data to a table which is more suitable for visualization:

(scatter.dt <- nc::capture_melt_multiple(
  full.dt,
  column="x[12]",
  "_",
  task_id="easy|impossible"))
#>         label person signal    task_id         x1         x2
#>        <fctr>  <int>  <num>     <char>      <num>      <num>
#>   1:     spam      1      3       easy  2.3735462  1.0744410
#>   2: not spam      1      0       easy  0.1836433  1.8956548
#>   3:     spam      1      3       easy  2.1643714 -0.6029973
#>   4: not spam      1      0       easy  1.5952808 -0.3908678
#>   5:     spam      1      3       easy  3.3295078 -0.4162220
#>  ---                                                        
#> 396: not spam      2      0 impossible -0.9243128 -1.0293917
#> 397:     spam      2      3 impossible  1.5929138  2.9890743
#> 398: not spam      2      0 impossible  0.0450106 -1.2249912
#> 399:     spam      2      3 impossible -0.7151284  0.4038886
#> 400: not spam      2      0 impossible  0.8652231  1.1691226

Below we visualize the pattern for each person and feature type:

if(require(animint2)){
  ggplot()+
    geom_point(aes(
      x1, x2, color=label),
      shape=1,
      data=scatter.dt)+
    facet_grid(
      task_id ~ person,
      labeller=label_both)
}

In the plot above, it is apparent that

  • for easy features (left), the two label classes differ in x1 values for both people. So it should be possible/easy to train on person 1, and predict accurately on person 2.
  • for impossible features (right), the two people have different label patterns. For person 1, the two label classes differ in x1 values, whereas for person 2, the two label classes differ in x2 values. So it should be impossible to train on person 1, and predict accurately on person 2.

Benchmark: computing test error

We use the code below to create a list of classification tasks, for use in the mlr3 framework.

class.task.list <- list()
for(task_id in c("easy","impossible")){
  feature.names <- grep(task_id, names(full.dt), value=TRUE)
  task.col.names <- c(feature.names, "label", "person")
  task.dt <- full.dt[, task.col.names, with=FALSE]
  this.task <- mlr3::TaskClassif$new(
    task_id, task.dt, target="label")
  this.task$col_roles$subset <- "person"
  this.task$col_roles$stratum <- c("person","label")
  this.task$col_roles$feature <- setdiff(names(task.dt), this.task$col_roles$stratum)
  class.task.list[[task_id]] <- this.task
}
class.task.list
#> $easy
#> <TaskClassif:easy> (200 x 3)
#> * Target: label
#> * Properties: twoclass, strata
#> * Features (2):
#>   - dbl (2): x1_easy, x2_easy
#> * Strata: person, label
#> 
#> $impossible
#> <TaskClassif:impossible> (200 x 3)
#> * Target: label
#> * Properties: twoclass, strata
#> * Features (2):
#>   - dbl (2): x1_impossible, x2_impossible
#> * Strata: person, label

Note in the code above that person is assigned roles subset and stratum, whereas label is assigned roles target and stratum. When adapting the code above to real data, the important parts are the mlr3::TaskClassif$new call, which tells mlr3 what data set to use and which column is the prediction target, and the col_roles assignments, which tell mlr3 which columns should be used as subset/stratum.
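
For concreteness, the sketch below shows what that task definition could look like for a real data set; my.dt and its column names (label, person, and the feature columns) are hypothetical placeholders to be replaced with your own data.

my.task <- mlr3::TaskClassif$new(
  "my_data", my.dt, target="label")
my.task$col_roles$subset <- "person"
my.task$col_roles$stratum <- c("person", "label")
my.task$col_roles$feature <- setdiff(
  names(my.dt), my.task$col_roles$stratum)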

The code below defines the same/other K-fold cross-validation method, as in the regression section above,

(class_same_other <- mlr3resampling::ResamplingSameOtherCV$new())
#> <ResamplingSameOtherCV> : Same versus Other Cross-Validation
#> * Iterations:
#> * Instantiated: FALSE
#> * Parameters:
#> List of 1
#>  $ folds: int 3

The code below is used to define the learning algorithms to test,

(class.learner.list <- list(
  if(requireNamespace("rpart"))mlr3::LearnerClassifRpart$new(),
  mlr3::LearnerClassifFeatureless$new()))
#> [[1]]
#> <LearnerClassifRpart:classif.rpart>: Classification Tree
#> * Model: -
#> * Parameters: xval=0
#> * Packages: mlr3, rpart
#> * Predict Types:  [response], prob
#> * Feature Types: logical, integer, numeric, factor, ordered
#> * Properties: importance, missings, multiclass, selected_features,
#>   twoclass, weights
#> 
#> [[2]]
#> <LearnerClassifFeatureless:classif.featureless>: Featureless Classification Learner
#> * Model: -
#> * Parameters: method=mode
#> * Packages: mlr3
#> * Predict Types:  [response], prob
#> * Feature Types: logical, integer, numeric, character, factor, ordered,
#>   POSIXct
#> * Properties: featureless, importance, missings, multiclass,
#>   selected_features, twoclass

The code below defines the grid of tasks, learners, and resamplings.

(class.bench.grid <- mlr3::benchmark_grid(
  class.task.list,
  class.learner.list,
  class_same_other))
#>          task             learner    resampling
#>        <char>              <char>        <char>
#> 1:       easy       classif.rpart same_other_cv
#> 2:       easy classif.featureless same_other_cv
#> 3: impossible       classif.rpart same_other_cv
#> 4: impossible classif.featureless same_other_cv

The code below runs the benchmark experiment defined by the grid above. Note that each iteration can be parallelized by declaring a future plan.

if(FALSE){
  if(require(future))plan("multisession")
}
if(require(lgr))get_logger("mlr3")$set_threshold("warn")
(class.bench.result <- mlr3::benchmark(
  class.bench.grid, store_models = TRUE))
#> <BenchmarkResult> of 72 rows with 4 resampling runs
#>  nr    task_id          learner_id resampling_id iters warnings errors
#>   1       easy       classif.rpart same_other_cv    18        0      0
#>   2       easy classif.featureless same_other_cv    18        0      0
#>   3 impossible       classif.rpart same_other_cv    18        0      0
#>   4 impossible classif.featureless same_other_cv    18        0      0

Below we compute scores (test error) for each resampling iteration, and show the first row of the result.

class.bench.score <- mlr3resampling::score(class.bench.result)
class.bench.score[1]
#>    train.subsets test.fold test.subset person iteration                  test
#>           <char>     <int>       <int>  <int>     <int>                <list>
#> 1:           all         1           1      1         1  1, 2, 8,11,12,18,...
#>                    train                                uhash    nr
#>                   <list>                               <char> <int>
#> 1:  3, 4, 5, 6, 9,10,... 18eca931-c440-4c14-bdd2-38e128d47b64     1
#>                  task task_id                             learner    learner_id
#>                <list>  <char>                              <list>        <char>
#> 1: <TaskClassif:easy>    easy <LearnerClassifRpart:classif.rpart> classif.rpart
#>                 resampling resampling_id          prediction classif.ce
#>                     <list>        <char>              <list>      <num>
#> 1: <ResamplingSameOtherCV> same_other_cv <PredictionClassif> 0.08823529
#>    algorithm
#>       <char>
#> 1:     rpart
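
As with the regression scores, these per-split error rates can be summarized before plotting. The sketch below (using the column names shown above; exact values depend on the random seed) reshapes the scores so that the two algorithms can be compared side by side, within each task and set of train subsets.

class.bench.wide <- dcast(
  class.bench.score,
  task_id + train.subsets + test.fold + person ~ algorithm,
  value.var="classif.ce")
class.bench.wide[, .(
  mean.featureless=mean(featureless),
  mean.rpart=mean(rpart)
), by=.(task_id, train.subsets)]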

Finally we plot the test error values below.

if(require(animint2)){
  ggplot()+
    geom_point(aes(
      classif.ce, train.subsets, color=algorithm),
      shape=1,
      data=class.bench.score)+
    facet_grid(
      person ~ task_id,
      labeller=label_both,
      scales="free")
}

It is clear from the plot above that

  • for the easy task, training on same is just as good as all or other subsets.
  • for the impossible task, we must train on the same subset for minimal test error; training on all is almost as good, because the pattern for person 1 is orthogonal to the pattern for person 2; training on other is just as bad as featureless, because the patterns for the two people are completely different.
  • in a real data task, training on other will most likely not be quite as bad as in the impossible task above, but also not as good as in the easy task.

Interactive visualization of data, test error, and splits

The code below can be used to create an interactive data visualization which allows exploring how different decision boundaries are learned during different splits.

inst <- class.bench.score$resampling[[1]]$instance
rect.expand <- 0.2
grid.value.dt <- scatter.dt[
, lapply(.SD, function(x)do.call(seq, c(as.list(range(x)), l=21)))
, .SDcols=c("x1","x2")]
grid.class.dt <- data.table(
  label=full.dt$label[1],
  do.call(
    CJ, grid.value.dt
  )
)
class.pred.dt.list <- list()
class.point.dt.list <- list()
for(score.i in 1:nrow(class.bench.score)){
  class.bench.row <- class.bench.score[score.i]
  task.dt <- data.table(
    class.bench.row$task[[1]]$data(),
    class.bench.row$resampling[[1]]$instance$id.dt)
  names(task.dt)[2:3] <- c("x1","x2")
  set.ids <- data.table(
    set.name=c("test","train")
  )[
  , data.table(row_id=class.bench.row[[set.name]][[1]])
  , by=set.name]
  i.points <- set.ids[
    task.dt, on="row_id"
  ][
    is.na(set.name), set.name := "unused"
  ][]
  class.point.dt.list[[score.i]] <- data.table(
    class.bench.row[, .(task_id, iteration)],
    i.points)
  if(class.bench.row$algorithm!="featureless"){
    i.learner <- class.bench.row$learner[[1]]
    i.learner$predict_type <- "prob"
    i.task <- class.bench.row$task[[1]]
    setnames(grid.class.dt, names(i.task$data()))
    grid.class.task <- mlr3::TaskClassif$new(
      "grid", grid.class.dt, target="label")
    pred.grid <- as.data.table(
      i.learner$predict(grid.class.task)
    )[, data.table(grid.class.dt, prob.spam)]
    names(pred.grid)[2:3] <- c("x1","x2")
    pred.wide <- dcast(pred.grid, x1 ~ x2, value.var="prob.spam")
    prob.mat <- as.matrix(pred.wide[,-1])
    contour.list <- contourLines(
      grid.value.dt$x1, grid.value.dt$x2, prob.mat, levels=0.5)
    class.pred.dt.list[[score.i]] <- data.table(
      class.bench.row[, .(
        task_id, iteration, algorithm
      )],
      data.table(contour.i=seq_along(contour.list))[, {
        do.call(data.table, contour.list[[contour.i]])[, .(level, x1=x, x2=y)]
      }, by=contour.i]
    )
  }
}
(class.pred.dt <- rbindlist(class.pred.dt.list))
#>         task_id iteration algorithm contour.i level       x1        x2
#>          <char>     <int>    <char>     <int> <num>    <num>     <num>
#>   1:       easy         1     rpart         1   0.5 1.856156 -3.008049
#>   2:       easy         1     rpart         1   0.5 1.856156 -2.606579
#>   3:       easy         1     rpart         1   0.5 1.856156 -2.205109
#>   4:       easy         1     rpart         1   0.5 1.856156 -1.803639
#>   5:       easy         1     rpart         1   0.5 1.856156 -1.402169
#>  ---                                                                  
#> 766: impossible        18     rpart         1   0.5 3.743510  1.225096
#> 767: impossible        18     rpart         1   0.5 4.158037  1.225096
#> 768: impossible        18     rpart         1   0.5 4.572564  1.225096
#> 769: impossible        18     rpart         1   0.5 4.987091  1.225096
#> 770: impossible        18     rpart         1   0.5 5.401618  1.225096
(class.point.dt <- rbindlist(class.point.dt.list))
#>           task_id iteration set.name row_id    label         x1         x2
#>            <char>     <int>   <char>  <int>   <fctr>      <num>      <num>
#>     1:       easy         1     test      1     spam  2.3735462  1.0744410
#>     2:       easy         1     test      2 not spam  0.1836433  1.8956548
#>     3:       easy         1    train      3     spam  2.1643714 -0.6029973
#>     4:       easy         1    train      4 not spam  1.5952808 -0.3908678
#>     5:       easy         1    train      5     spam  3.3295078 -0.4162220
#>    ---                                                                    
#> 14396: impossible        18    train    196 not spam -0.9243128 -1.0293917
#> 14397: impossible        18    train    197     spam  1.5929138  2.9890743
#> 14398: impossible        18    train    198 not spam  0.0450106 -1.2249912
#> 14399: impossible        18    train    199     spam -0.7151284  0.4038886
#> 14400: impossible        18    train    200 not spam  0.8652231  1.1691226
#>         fold person subset display_row
#>        <int>  <int>  <int>       <int>
#>     1:     1      1      1           1
#>     2:     1      1      1           2
#>     3:     2      1      1          35
#>     4:     2      1      1          36
#>     5:     2      1      1          37
#>    ---                                
#> 14396:     2      2      2         166
#> 14397:     2      2      2         167
#> 14398:     1      2      2         133
#> 14399:     1      2      2         134
#> 14400:     2      2      2         168

set.colors <- c(
  train="#1B9E77",
  test="#D95F02",
  unused="white")
algo.colors <- c(
  featureless="blue",
  rpart="red")
make_person_subset <- function(DT){
  DT[, "person/subset" := person]
}
make_person_subset(class.point.dt)
make_person_subset(class.bench.score)
if(require(animint2)){
  viz <- animint(
    title="Train/predict on subsets, classification",
    pred=ggplot()+
      ggtitle("Predictions for selected train/test split")+
      theme_animint(height=400)+
      scale_fill_manual(values=set.colors)+
      scale_color_manual(values=c(spam="black","not spam"="white"))+
      geom_point(aes(
        x1, x2, color=label, fill=set.name),
        showSelected="iteration",
        size=3,
        stroke=2,
        shape=21,
        data=class.point.dt)+
      geom_path(aes(
        x1, x2, 
        subset=paste(algorithm, iteration, contour.i)),
        showSelected=c("iteration","algorithm"),
        color=algo.colors[["rpart"]],
        data=class.pred.dt)+
      facet_grid(
        task_id ~ `person/subset`,
        labeller=label_both,
        space="free",
        scales="free")+
      scale_y_continuous(
        breaks=seq(-100, 100, by=2)),
    err=ggplot()+
      ggtitle("Test error for each split")+
      theme_animint(height=400)+
      theme(panel.margin=grid::unit(1, "lines"))+
      scale_y_continuous(
        "Classification error on test set",
        breaks=seq(0, 1, by=0.25))+
      scale_fill_manual(values=algo.colors)+
      scale_x_discrete(
        "People/subsets in train set")+
      geom_hline(aes(
        yintercept=yint),
        data=data.table(yint=0.5),
        color="grey50")+
      geom_point(aes(
        train.subsets, classif.ce, fill=algorithm),
        shape=1,
        size=5,
        stroke=2,
        color="black",
        color_off=NA,
        clickSelects="iteration",
        data=class.bench.score)+
      facet_grid(
        task_id ~ `person/subset`,
        labeller=label_both),
    diagram=ggplot()+
      ggtitle("Select train/test split")+
      theme_bw()+
      theme_animint(height=300)+
      facet_grid(
        . ~ train.subsets,
        scales="free",
        space="free")+
      scale_size_manual(values=c(subset=3, fold=1))+
      scale_color_manual(values=c(subset="orange", fold="grey50"))+
      geom_rect(aes(
        xmin=-Inf, xmax=Inf,
        color=rows,
        size=rows,
        ymin=display_row, ymax=display_end),
        fill=NA,
        data=inst$viz.rect.dt)+
      scale_fill_manual(values=set.colors)+
      geom_rect(aes(
        xmin=iteration-rect.expand, ymin=display_row,
        xmax=iteration+rect.expand, ymax=display_end,
        fill=set.name),
        clickSelects="iteration",
        data=inst$viz.set.dt)+
      geom_text(aes(
        ifelse(rows=="subset", Inf, -Inf),
        (display_row+display_end)/2,
        hjust=ifelse(rows=="subset", 1, 0),
        label=paste0(rows, "=", ifelse(rows=="subset", subset, fold))),
        data=data.table(train.name="same", inst$viz.rect.dt))+
      scale_x_continuous(
        "Split number / cross-validation iteration")+
      scale_y_continuous(
        "Row number"),
    source="https://github.com/tdhock/mlr3resampling/blob/main/vignettes/ResamplingSameOtherCV.Rmd")
  viz
}

if(FALSE){
  animint2pages(viz, "2023-12-13-train-predict-subsets-classification")
}

If you are viewing this in an installed package or on CRAN, then there will be no data viz on this page, but you can view it on: https://tdhock.github.io/2023-12-13-train-predict-subsets-classification/

Conclusion

In this vignette we have shown how to use mlr3resampling for comparing test error of models trained on same/all/other subsets.

Session info

sessionInfo()
#> R Under development (unstable) (2024-01-23 r85822 ucrt)
#> Platform: x86_64-w64-mingw32/x64
#> Running under: Windows 10 x64 (build 19045)
#> 
#> Matrix products: default
#> 
#> 
#> locale:
#> [1] LC_COLLATE=C                          
#> [2] LC_CTYPE=English_United States.utf8   
#> [3] LC_MONETARY=English_United States.utf8
#> [4] LC_NUMERIC=C                          
#> [5] LC_TIME=English_United States.utf8    
#> 
#> time zone: America/Phoenix
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] mlr3_0.18.0        lgr_0.4.4          animint2_2024.1.24 data.table_1.15.99
#> 
#> loaded via a namespace (and not attached):
#>  [1] future.apply_1.11.2      gtable_0.3.4             jsonlite_1.8.8          
#>  [4] highr_0.10               compiler_4.4.0           crayon_1.5.2            
#>  [7] rpart_4.1.23             Rcpp_1.0.12              stringr_1.5.1           
#> [10] parallel_4.4.0           jquerylib_0.1.4          globals_0.16.3          
#> [13] scales_1.3.0             uuid_1.2-0               RhpcBLASctl_0.23-42     
#> [16] yaml_2.3.8               fastmap_1.1.1            R6_2.5.1                
#> [19] plyr_1.8.9               labeling_0.4.3           knitr_1.46              
#> [22] palmerpenguins_0.1.1     backports_1.4.1          checkmate_2.3.1         
#> [25] future_1.33.2            munsell_0.5.1            paradox_0.11.1          
#> [28] bslib_0.7.0              mlr3measures_0.5.0       rlang_1.1.3             
#> [31] stringi_1.8.3            cachem_1.0.8             xfun_0.43               
#> [34] mlr3misc_0.15.0          sass_0.4.9               RJSONIO_1.3-1.9         
#> [37] cli_3.6.2                magrittr_2.0.3           digest_0.6.34           
#> [40] grid_4.4.0               nc_2024.2.21             lifecycle_1.0.4         
#> [43] evaluate_0.23            glue_1.7.0               farver_2.1.1            
#> [46] listenv_0.9.1            codetools_0.2-19         parallelly_1.37.1       
#> [49] colorspace_2.1-0         reshape2_1.4.4           rmarkdown_2.26          
#> [52] mlr3resampling_2024.4.14 tools_4.4.0              htmltools_0.5.8.1