## ----dp-knit-opts, include = FALSE--------------------------------------------
# Document-wide chunk options: collapse source and output into one block,
# prefix printed output with "#>", and draw centred 6 x 4 inch figures.
knitr::opts_chunk$set(
  collapse   = TRUE,
  comment    = "#>",
  fig.width  = 6,
  fig.height = 4,
  fig.align  = "center"
)

## ----setup--------------------------------------------------------------------
# Attach the package being demonstrated and seed the RNG so that the
# simulated data in the following chunks is reproducible.
library("MetaHunt")
set.seed(seed = 1)

## ----build-grid-1d------------------------------------------------------------
# Simulated reference sample: 500 rows with columns `age`, `bp` and `bmi`,
# each drawn from a normal distribution.
ref <- data.frame(
  age = rnorm(n = 500, mean = 60, sd = 10),
  bp  = rnorm(n = 500, mean = 130, sd = 15),
  bmi = rnorm(n = 500, mean = 28, sd = 4)
)

# Build a 50-point evaluation grid from the reference sample; `seed = 1` is
# passed through to build_grid(). Inspect its size and first rows.
grid <- build_grid(ref, n_grid = 50, seed = 1)
dim(grid)
head(grid)

## ----ranger-example, eval = FALSE---------------------------------------------
# library(ranger)
# centre_models <- lapply(centre_data_list,
#                         function(d) ranger(y ~ ., data = d))
# F_hat <- f_hat_from_models(centre_models, grid)

## ----grf-example, eval = FALSE------------------------------------------------
# library(grf)
# centre_models <- lapply(centre_data_list,
#                         function(d) causal_forest(d$X, d$Y, d$W))
# F_hat <- f_hat_from_models(centre_models, grid)

## ----lm-onramp----------------------------------------------------------------
# Metadata for m = 8 toy centres. Only `mean_age` feeds the simulation
# below; `region` and `pct_female` are illustrative extra columns.
m <- 8
centre_meta <- data.frame(
  region     = factor(
    sample(c("N", "S", "E", "W"), size = m, replace = TRUE)
  ),
  mean_age   = round(runif(m, min = 50, max = 70)),
  pct_female = round(runif(m, min = 0.4, max = 0.6), digits = 2)
)

# Simulate one centre's data: 80 draws of x ~ U(-1, 1) and
# y = slope * x + 0.3 * x^2 + N(0, 0.2^2) noise. The linear slope is the
# centre's mean_age / 60, a toy link from metadata to outcome.
make_centre_data <- function(i) {
  n_obs <- 80
  x <- runif(n_obs, min = -1, max = 1)
  slope <- centre_meta$mean_age[i] / 60
  noise <- rnorm(n_obs, sd = 0.2)
  data.frame(x = x, y = slope * x + 0.3 * x^2 + noise)
}

# Each centre fits a quadratic in `x` (degree-2 poly() basis).
centre_models <- lapply(seq_len(m), function(i) {
  stats::lm(y ~ poly(x, 2), data = make_centre_data(i))
})

# Shared 1-D evaluation grid: 30 equally spaced points on [-1, 1].
grid_centres <- data.frame(x = seq(-1, 1, length.out = 30))
F_hat_centres <- f_hat_from_models(centre_models, grid_centres)

dim(F_hat_centres)        # 8 x 30 (centres x grid points)
F_hat_centres[1:3, 1:5]

## ----predict-fn-demo----------------------------------------------------------
# Toy "models": plain lists holding a slope (0.25, 0.5, 0.75, 1) and a zero
# intercept. They are not fitted model objects, so a custom predict_fn is
# supplied to tell f_hat_from_models() how to evaluate each one.
fake_models <- lapply(seq_len(4) / 4, function(s) {
  list(slope = s, intercept = 0)
})

custom_predict <- function(model, grid) {
  # Straight line in x: intercept + slope * x, one value per grid row.
  slope_term <- model$slope * grid$x
  model$intercept + slope_term
}

F_hat_custom <- f_hat_from_models(
  fake_models, grid_centres,
  predict_fn = custom_predict
)
dim(F_hat_custom)         # 4 x 30

## ----multi-d-grid-------------------------------------------------------------
# Reference sample with three covariates (400 rows) and a 25-point grid
# built from it with build_grid().
ref3 <- data.frame(
  age = rnorm(400, mean = 60, sd = 10),
  bp  = rnorm(400, mean = 130, sd = 15),
  bmi = rnorm(400, mean = 28, sd = 4)
)
grid3 <- build_grid(ref3, n_grid = 25, seed = 1)
dim(grid3)

# Simulate m3 = 8 centres, each with 60 rows and an outcome that is linear
# in (age, bp, bmi). The slopes shift with the centre index, so there is
# genuine cross-centre heterogeneity to recover.
set.seed(2)
m3 <- 8
centre_data3 <- lapply(seq_len(m3), function(idx) {
  n_i <- 60
  age <- rnorm(n_i, 60, 10)
  bp  <- rnorm(n_i, 130, 15)
  bmi <- rnorm(n_i, 28, 4)
  b_age <- 0.02 + 0.03 * (idx / m3)
  b_bp  <- -0.01 + 0.02 * cos(pi * idx / m3)
  b_bmi <- 0.05 - 0.04 * (idx / m3)
  eps <- rnorm(n_i, sd = 0.3)
  data.frame(
    age = age, bp = bp, bmi = bmi,
    y = b_age * age + b_bp * bp + b_bmi * bmi + eps
  )
})

# One main-effects linear model per centre.
centre_models3 <- lapply(centre_data3, function(d) {
  stats::lm(y ~ age + bp + bmi, data = d)
})

F_hat3 <- f_hat_from_models(centre_models3, grid3)
dim(F_hat3)               # 8 x 25

## ----sanity-------------------------------------------------------------------
# Right shape: m studies x G grid points (here 8 x 25).
dim(F_hat3)

# Numeric, no NA. anyNA() does not flag Inf/-Inf, so also require every
# entry to be finite; an infinite value would dominate any summary.
is.numeric(F_hat3)
anyNA(F_hat3)
all(is.finite(F_hat3))

# Rows look like functions of similar magnitude (large outliers can
# dominate d-fSPA's `Delta`). apply(..., 1, ...) returns a 2 x m matrix
# (rows min/max, one column per study). Transpose it so summary()
# describes the per-study minima and maxima across studies, instead of
# giving one summary of two numbers per study.
summary(t(apply(F_hat3, 1, function(r) c(min = min(r), max = max(r)))))

## ----sanity-plot--------------------------------------------------------------
# Overlay every centre's fitted curve (one column of t(F_hat_centres) per
# centre) on the shared 1-D grid as solid grey lines.
matplot(
  x = grid_centres$x,
  y = t(F_hat_centres),
  type = "l", lty = 1, col = "grey50",
  xlab = "x",
  ylab = expression(hat(f)(x)),
  main = "Per-centre fitted functions on the shared grid"
)

