## ----ws-knit-opts, include = FALSE--------------------------------------------
# Global chunk options for this vignette: merge source and output into one
# block, prefix printed output with "#>", and draw centred 6-by-4 figures
# (knitr measures fig.width / fig.height in inches).
knitr::opts_chunk$set(
  collapse   = TRUE,
  comment    = "#>",
  fig.width  = 6,
  fig.height = 4,
  fig.align  = "center"
)

## ----setup--------------------------------------------------------------------
# Attach the package being demonstrated, then fix the RNG seed so that every
# random draw made later in this script is reproducible.
library("MetaHunt")
set.seed(seed = 1)

## ----ws-simulate-trials, eval = requireNamespace("grf", quietly = TRUE)-------
# Simulation settings: number of sites (trials), patients per site, and the
# number of grid points at which each site's CATE curve is evaluated.
m <- 8
n_per_site <- 200
G <- 30

# Site-level covariates: trial year and the per-patient treatment probability.
W <- data.frame(
  year        = sample(2010:2020, m, replace = TRUE),
  pct_treated = round(runif(m, 0.3, 0.6), 2)
)

# One data frame per site with columns Y (outcome), age and T (0/1 treatment).
# The true CATE is 0.02 * (age - 50) plus a site-level shift of
# (year - 2015) / 5, so later trials have larger treatment effects.
site_data_list <- lapply(seq_len(m), function(i) {
  age <- runif(n_per_site, 30, 80)
  # Named `trt` rather than `T`: a variable called `T` masks the TRUE shorthand.
  trt <- rbinom(n_per_site, 1, W$pct_treated[i])
  site_eff <- (W$year[i] - 2015) / 5   # site-level shift in CATE
  tau_age  <- 0.02 * (age - 50) + site_eff
  Y0 <- 0.01 * age + rnorm(n_per_site, sd = 0.5)
  Y1 <- Y0 + tau_age
  Y  <- ifelse(trt == 1, Y1, Y0)
  # The column keeps the name `T` because downstream code reads it as `$T`.
  data.frame(Y = Y, age = age, T = trt)
})

# Common evaluation grid: G equally spaced ages covering the simulated range.
grid <- data.frame(age = seq(30, 80, length.out = G))

## ----ws-fit-cf, eval = requireNamespace("grf", quietly = TRUE)----------------
# Fit one grf causal forest per site, with age as the only covariate, the
# site's outcome Y, and its 0/1 treatment indicator T; 200 trees each.
cf_models <- lapply(site_data_list, function(site_df) {
  covariates <- matrix(site_df$age, ncol = 1)
  grf::causal_forest(
    X         = covariates,
    Y         = site_df$Y,
    W         = site_df$T,
    num.trees = 200
  )
})

## ----ws-build-fhat, eval = requireNamespace("grf", quietly = TRUE)------------
# Evaluate a fitted causal forest at every age in `grid` and return the
# `predictions` component as a plain numeric vector (one value per grid row).
cate_predict <- function(model, grid) {
  grid_matrix <- matrix(grid$age, ncol = 1)
  forest_out  <- stats::predict(model, newdata = grid_matrix)
  as.numeric(forest_out$predictions)
}

# Collect the per-site CATE estimates on the common grid via the package
# helper; dim() shows the resulting shape.
F_hat <- f_hat_from_models(cf_models, grid, predict_fn = cate_predict)
dim(F_hat)

## ----ws-fit-metahunt, eval = requireNamespace("grf", quietly = TRUE)----------
# Covariates of a hypothetical new trial to predict for.
W_new <- data.frame(year = 2018, pct_treated = 0.45)

# Fit MetaHunt on the per-site CATE estimates and site covariates with K = 3;
# `dfspa_args` is passed through as list(denoise = FALSE).
fit <- metahunt(
  F_hat, W,
  K          = 3,
  dfspa_args = list(denoise = FALSE)
)

# Scalar prediction for the new trial; `wrapper = mean` presumably averages the
# predicted CATE curve over the grid (hence the name ate_pred).
ate_pred <- predict(fit, newdata = W_new, wrapper = mean)
ate_pred

## ----wrapper-mean, eval = requireNamespace("grf", quietly = TRUE)-------------
# Scalar summary via `wrapper = mean`, presumably the grid average of the
# predicted CATE curve for W_new.
predict(
  fit,
  wrapper = mean,
  newdata = W_new
)

## ----wrapper-restricted, eval = requireNamespace("grf", quietly = TRUE)-------
# Average of the positive part of a curve: negative values are clipped to 0
# before summing, and the divisor is still the full number of points.
restricted_pos_mean <- function(f) {
  clipped <- pmax(f, 0)
  sum(clipped) / length(f)
}
predict(fit, newdata = W_new, wrapper = restricted_pos_mean)

## ----wrapper-endpoint, eval = requireNamespace("grf", quietly = TRUE)---------
# Last value minus first value of a curve. Assuming `f` follows the grid order,
# this contrasts the two ends of the age grid.
endpoint_contrast <- function(f) {
  n_pts <- length(f)
  f[n_pts] - f[1L]
}
predict(fit, newdata = W_new, wrapper = endpoint_contrast)

## ----ws-split-scalar, eval = requireNamespace("grf", quietly = TRUE)----------
# Hold out the last site as the "new" trial and use the remaining m - 1 sites
# for training + calibration (cal_frac = 0.5 presumably splits them in half).
# Indices are derived from `m` so this chunk stays valid if `m` changes.
new_site <- m
tr_cal <- seq_len(m - 1L)
res <- split_conformal(
  F_hat[tr_cal, , drop = FALSE],
  W[tr_cal, , drop = FALSE],
  W[new_site, , drop = FALSE],
  K = 3, wrapper = mean, alpha = 0.1, cal_frac = 0.5, seed = 1,
  dfspa_args = list(denoise = FALSE)
)
# Point prediction and interval bounds for the held-out site
# (alpha = 0.1, i.e. presumably a 90% interval).
data.frame(prediction = res$prediction,
           lower      = res$lower,
           upper      = res$upper)

