# Colon-cancer trial data from {survival} as a data.table: rows with any
# missing value are dropped, then only records with `etype == 2` are kept.
dataset <- data.table::as.data.table(survival::colon)
dataset <- na.omit(dataset)
dataset <- dataset[get("etype") == 2L]

# Outcome-related columns: event indicator, follow-up time and treatment arm
# (only the first two are used to build the survival response below).
surv_cols <- c("status", "time", "rx")
# Fixed RNG seed shared by the fold creation and the nested-CV run.
seed <- 123

# Candidate predictors: every column from the third one up to (but not
# including) the last one. For the `survival::colon` layout this presumably
# skips `id`/`study` at the front and `etype` at the end — confirm if the
# upstream data layout changes.
feature_cols <- colnames(dataset)[seq.int(from = 3L, to = ncol(dataset) - 1L)]

# Tuning grid over the elastic-net mixing parameter alpha:
# 0, 0.2, 0.4, 0.6, 0.8, 1 (glmnet: 0 = ridge, 1 = lasso).
param_list_glmnet <- expand.grid(alpha = seq(from = 0, to = 1, by = 0.2))

# Decide how many cores the tuning may use.
#
# When R CMD check limits core usage (any non-empty `_R_CHECK_LIMIT_CORES_`
# value other than "false", e.g. "TRUE" or "warn"), at most two cores are
# allowed, so 2 is returned. Otherwise the detected core count is clamped to
# [1, 4]. `parallel::detectCores()` may return NA; that is treated as a
# single core.
#
# @param limit_env Raw value of the `_R_CHECK_LIMIT_CORES_` variable.
# @param detected Number of detected cores (may be `NA`); only evaluated
#   when core usage is not limited.
# @return A single integer between 1 and 4.
resolve_test_ncores <- function(
    limit_env = Sys.getenv("_R_CHECK_LIMIT_CORES_", unset = ""),
    detected = parallel::detectCores()) {
  limit_env <- tolower(limit_env)
  if (isTRUE(nzchar(limit_env) && limit_env != "false")) {
    # on CRAN / limited R CMD check
    return(2L)
  }
  if (is.na(detected)) {
    return(1L)
  }
  max(1L, min(4L, as.integer(detected)))
}

ncores <- resolve_test_ncores()

# Stratification vector: rows are grouped into 4 k-means clusters computed
# on the outcome-related columns listed in `surv_cols`.
split_vector <- splitTools::multi_strata(
  df = dataset[, .SD, .SDcols = surv_cols],
  k = 4,
  strategy = "kmeans"
)

# Predictor matrix: every feature column except the survival outcome columns
# (`status`, `time`); the `-1` term drops the intercept column.
train_x <- model.matrix(
  object = ~ -1 + .,
  data = dataset[, .SD, .SDcols = setdiff(feature_cols, surv_cols[1:2])]
)
# Right-censored survival response from `time` and the `status` event
# indicator; `status` goes through character first so a factor-coded
# indicator yields its labels rather than its level codes.
train_y <- survival::Surv(
  time = dataset[["time"]],
  event = as.integer(as.character(dataset[["status"]])),
  type = "right"
)

# Three stratified outer folds built on the strata in `split_vector`,
# reproducible through `seed`.
fold_list <- splitTools::create_folds(
  y = split_vector,
  type = "stratified",
  k = 3,
  seed = seed
)

# NOTE(review): presumably caps the number of initialisation points used by
# mlexperiments' Bayesian optimisation — confirm against mlexperiments docs.
options(mlexperiments.bayesian.max_init = 4L)

# ###########################################################################
# %% TUNING
# ###########################################################################

# Bayesian search space: alpha is searched on the closed interval [0, 1].
glmnet_bounds <- list(alpha = c(0, 1))
# Optimiser settings: one iteration per available core, and an upper
# confidence bound ("ucb") acquisition function with kappa = 3.5.
optim_args <- list(n_iter = ncores, kappa = 3.5, acq = "ucb")

# ###########################################################################
# %% NESTED CV
# ###########################################################################

# Nested cross-validation of the glmnet Cox learner. Each outer fold of
# `fold_list` is tuned with the Bayesian strategy. The description previously
# said "grid", which did not match `strategy = "bayesian"`.
test_that(desc = "test nested cv, bayesian - surv_glmnet_cox", code = {
  # Skip when a package needed by the learner or the optimiser is missing.
  testthat::skip_if_not_installed("survival")
  testthat::skip_if_not_installed("glmnet")
  testthat::skip_if_not_installed("rBayesianOptimization")

  # `k_tuning` presumably sets the number of inner tuning folds.
  surv_glmnet_cox_optimizer <- mlexperiments::MLNestedCV$new(
    learner = LearnerSurvGlmnetCox$new(),
    strategy = "bayesian",
    fold_list = fold_list,
    k_tuning = 3L,
    ncores = ncores,
    seed = seed
  )

  # Bayesian search space plus a parameter grid. The grid is presumably used
  # as initial points for the optimiser; confirm against the mlexperiments
  # docs.
  surv_glmnet_cox_optimizer$parameter_bounds <- glmnet_bounds
  surv_glmnet_cox_optimizer$parameter_grid <- param_list_glmnet
  # Inner folds are stratified on the same k-means strata as the outer ones.
  surv_glmnet_cox_optimizer$split_type <- "stratified"
  surv_glmnet_cox_optimizer$split_vector <- split_vector
  surv_glmnet_cox_optimizer$optim_args <- optim_args

  surv_glmnet_cox_optimizer$performance_metric <- c_index

  # set data
  surv_glmnet_cox_optimizer$set_data(
    x = train_x,
    y = train_y
  )

  cv_results <- surv_glmnet_cox_optimizer$execute()
  expect_type(cv_results, "list")
  # Three rows, presumably one per outer fold in `fold_list` (k = 3).
  expect_equal(dim(cv_results), c(3, 4))
  expect_true(inherits(
    x = surv_glmnet_cox_optimizer$results,
    what = "mlexCV"
  ))
})
