---
title: "Scalar summaries with wrapper"
output:
  rmarkdown::html_vignette: default
vignette: >
  %\VignetteIndexEntry{Scalar summaries with wrapper}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r ws-knit-opts, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.width = 6, fig.height = 4,
  fig.align = "center"
)
```

```{r setup}
library(MetaHunt)
set.seed(1)
```

## Why scalar summaries

`metahunt()` predicts a function on the grid; conformal routines can
return a band at every grid point. Often, though, the inferential
target is a single number derived from that function:

- The average treatment effect (mean of a CATE function over a
  reference patient distribution).
- The treatment effect at a specific patient profile.
- The fraction of the population with a positive treatment effect.
- A contrast between two endpoints.

For all of these, MetaHunt accepts a `wrapper` argument that
**collapses the predicted function to a scalar** before any further
calculation. The same wrapper is applied identically to predictions
and to calibration residuals, so conformal coverage transfers
directly to the scalar summary.
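To make the list above concrete, here is a minimal base-R sketch of three of those targets written as wrappers. It assumes a toy linear CATE on an equally spaced 30-point age grid; `age_grid`, `tau`, `make_at_age()` and the other names are illustrative only, not MetaHunt functions. "Fraction of the population" becomes "fraction of grid points" here, which matches the population quantity only if the grid represents the reference distribution.

```r
# Toy CATE on a 30-point age grid (illustrative, not estimated from data)
age_grid <- seq(30, 80, length.out = 30)
tau <- 0.02 * (age_grid - 50)

# Average effect over the grid (uniform reference distribution)
ate_wrapper <- function(f) mean(f)

# Effect at one patient profile: value at the grid point nearest `target`.
# A wrapper only receives the vector f, so the grid is captured in a closure.
make_at_age <- function(grid_age, target) {
  idx <- which.min(abs(grid_age - target))
  function(f) f[idx]
}
at_age_65 <- make_at_age(age_grid, 65)

# Share of grid points with a positive effect
frac_positive <- function(f) mean(f > 0)

ate_wrapper(tau)    # 0.1
at_age_65(tau)
frac_positive(tau)  # 0.6
```

The closure pattern in `make_at_age()` is one way to give a wrapper access to grid locations while keeping the one-argument signature.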

## The wrapper protocol

`apply_wrapper(F_mat, wrapper, grid_weights)` defines the contract.

- `F_mat` is an `n`-by-`G_grid` numeric matrix; row `j` is one
  function on the grid.
- If `wrapper` is `NULL`, `apply_wrapper()` returns the weighted mean
  of each row `f`, `sum(grid_weights * f) / sum(grid_weights)`, with
  `grid_weights` uniform (`1/G_grid` per point) by default.
- If `wrapper` is a function, `apply_wrapper()` calls `apply(F_mat,
  1, wrapper)`, which means **the wrapper receives a single numeric
  vector of length `G_grid`** — one row of `F_mat` at a time — and
  must return a single numeric value.
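The two branches described above can be sketched in a few lines of base R. `apply_wrapper_sketch()` is a hypothetical re-implementation of the documented behaviour for illustration, not MetaHunt's source code.

```r
# Hypothetical sketch of the documented apply_wrapper() behaviour
apply_wrapper_sketch <- function(F_mat, wrapper = NULL,
                                 grid_weights = rep(1 / ncol(F_mat), ncol(F_mat))) {
  if (is.null(wrapper)) {
    # Weighted row means: sum_g w_g * F[j, g] / sum_g w_g
    return(as.numeric(F_mat %*% grid_weights) / sum(grid_weights))
  }
  # The wrapper sees one length-G_grid row at a time
  apply(F_mat, 1, wrapper)
}

F_mat <- rbind(c(0, 1, 2),
               c(1, 1, 4))
apply_wrapper_sketch(F_mat)                              # 1.00 2.00
apply_wrapper_sketch(F_mat, grid_weights = c(1, 1, 2))   # 1.25 2.50
apply_wrapper_sketch(F_mat, wrapper = max)               # 2 4
```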

The contract therefore is:

```
wrapper :: numeric vector of length G_grid  ->  numeric scalar
```

Any function satisfying that signature is a valid wrapper. After the
call, the package checks that the result is numeric and has exactly
one entry per row of `F_mat`.
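The post-hoc check matters because `apply()` does not fail on a wrapper that returns more than one value; it silently returns a matrix instead. The sketch below uses a hypothetical `check_wrapped()` helper that mirrors the rule stated above; it is not the package's internal check.

```r
# Hypothetical check mirroring the documented rule:
# numeric, no dim attribute, one value per row of F_mat
check_wrapped <- function(out, n_rows) {
  is.numeric(out) && is.null(dim(out)) && length(out) == n_rows
}

F_mat <- matrix(c(1, 2, 3,
                  4, 5, 6), nrow = 2, byrow = TRUE)
good <- apply(F_mat, 1, mean)    # length-2 vector: one scalar per row
bad  <- apply(F_mat, 1, range)   # 2 x 2 matrix: range() returns two values
check_wrapped(good, nrow(F_mat)) # TRUE
check_wrapped(bad,  nrow(F_mat)) # FALSE
```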

## An ATE example with `grf::causal_forest`

We simulate a multi-site clinical trial with `m = 8` sites. Each
site has its own individual-level data $(Y, X, T)$ where $Y$ is a
continuous outcome, $X$ is a single patient covariate (`age`), and
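$T$ is a binary treatment indicator. The site-level CATE function
$\tau^{(i)}(\text{age}) = E[Y(1) - Y(0) \mid \text{age}, \text{site} = i]$
varies across sites through the site's metadata: in the simulation
below, $\tau^{(i)}(\text{age}) = 0.02\,(\text{age} - 50) + (\text{year}_i - 2015)/5$,
while `pct_treated` only sets the probability of treatment.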
Each site fits its own `grf::causal_forest` on its individual-level
data, and shares only the fitted model — not the patient data —
with us.

```{r ws-simulate-trials, eval = requireNamespace("grf", quietly = TRUE)}
m <- 8
n_per_site <- 200
G <- 30

W <- data.frame(
  year        = sample(2010:2020, m, replace = TRUE),
  pct_treated = round(runif(m, 0.3, 0.6), 2)
)

site_data_list <- lapply(seq_len(m), function(i) {
  age <- runif(n_per_site, 30, 80)
  T   <- rbinom(n_per_site, 1, W$pct_treated[i])
  site_eff <- (W$year[i] - 2015) / 5   # site-level shift in CATE
  tau_age  <- 0.02 * (age - 50) + site_eff
  Y0  <- 0.01 * age + rnorm(n_per_site, sd = 0.5)
  Y1  <- Y0 + tau_age
  Y   <- ifelse(T == 1, Y1, Y0)
  data.frame(Y = Y, age = age, T = T)
})

grid <- data.frame(age = seq(30, 80, length.out = G))
```

Each site fits its own `causal_forest`. We use `num.trees = 200` to
keep the vignette fast; in practice you would use the default 2000
or more.

```{r ws-fit-cf, eval = requireNamespace("grf", quietly = TRUE)}
cf_models <- lapply(site_data_list, function(d)
  grf::causal_forest(X = matrix(d$age, ncol = 1),
                     Y = d$Y,
                     W = d$T,
                     num.trees = 200))
```

We stack the per-site CATE estimates on the shared `age` grid into
the `m`-by-`G` matrix `F_hat`. The dispatch table inside
`f_hat_from_models()` already knows how to call `causal_forest`, so
for standard `grf::causal_forest` models you can omit `predict_fn`;
we pass one explicitly here only to illustrate the general pattern.

```{r ws-build-fhat, eval = requireNamespace("grf", quietly = TRUE)}
cate_predict <- function(model, grid) {
  as.numeric(stats::predict(model, newdata = matrix(grid$age, ncol = 1))$predictions)
}
F_hat <- f_hat_from_models(cf_models, grid, predict_fn = cate_predict)
dim(F_hat)
```
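The `predict_fn` contract used above is `function(model, grid)` returning one prediction per grid row. As a rough, package-free illustration of what stacking produces, the sketch below uses `lm()` fits as stand-ins for per-site models and builds the `m`-by-`G` matrix with `vapply()`. All names here (`toy_grid`, `lm_predict`, `F_toy`) are hypothetical, and this is not the source of `f_hat_from_models()`.

```r
set.seed(4)
n_sites  <- 3
toy_grid <- data.frame(age = seq(30, 80, length.out = 5))

# Stand-in per-site models: one lm() per site
models <- lapply(seq_len(n_sites), function(i) {
  d <- data.frame(age = runif(50, 30, 80))
  d$y <- 0.02 * d$age + i + rnorm(50, sd = 0.1)
  lm(y ~ age, data = d)
})

# predict_fn with the (model, grid) -> length-nrow(grid) vector signature
lm_predict <- function(model, grid) as.numeric(predict(model, newdata = grid))

# vapply gives G x m; transpose so rows are sites and columns are grid points
F_toy <- t(vapply(models, lm_predict, numeric(nrow(toy_grid)), grid = toy_grid))
dim(F_toy)  # 3 5
```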

We now fit `metahunt()` on `(F_hat, W)` and ask for the predicted
ATE at a hypothetical new site.

```{r ws-fit-metahunt, eval = requireNamespace("grf", quietly = TRUE)}
fit <- metahunt(F_hat, W, K = 3, dfspa_args = list(denoise = FALSE))
W_new <- data.frame(year = 2018, pct_treated = 0.45)
ate_pred <- predict(fit, newdata = W_new, wrapper = mean)
ate_pred
```

The scalar `ate_pred` is the predicted average treatment effect for
a hypothetical new site with metadata `(year = 2018, pct_treated =
0.45)`, taking the unweighted mean over the 30-point age grid.
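Because the data are simulated, the grid-mean CATE at a site with `year = 2018` is known from the data-generating process above (`pct_treated` does not enter the CATE). The sketch recomputes it in base R; `true_cate()` is a hypothetical helper that simply restates the simulation formula. Comparing `ate_pred` with this value is only a rough sanity check, since eight sites and 200-tree forests leave plenty of estimation noise.

```r
# True CATE from the simulation: 0.02 * (age - 50) + (year - 2015) / 5
true_cate <- function(age, year) 0.02 * (age - 50) + (year - 2015) / 5

age_grid <- seq(30, 80, length.out = 30)
true_ate_new <- mean(true_cate(age_grid, year = 2018))
true_ate_new  # 0.7
```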

## Three custom wrappers

Below are three short, self-contained wrappers, each illustrating a
different idea. Each is passed to `predict()` together with the `fit`
and `W_new` constructed in the previous section.

### Plain `mean`

`mean` already maps a numeric vector to a single number, so it is a
valid wrapper. On an equally spaced grid this is the unweighted
average of the function over the grid, i.e. the grid-uniform ATE.
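`mean` matches a uniform average over age only when the grid is equally spaced. On an unevenly spaced grid, one option is a closure that applies trapezoidal weights. The helpers below (`trapezoid_weights()`, `make_trapezoid_mean()`) are hypothetical and not part of MetaHunt; with `wrapper = NULL`, the documented `grid_weights` argument serves a similar purpose.

```r
# Trapezoidal quadrature weights for a sorted grid x
trapezoid_weights <- function(x) {
  dx <- diff(x)
  c(dx, 0) / 2 + c(0, dx) / 2
}

# Wrapper factory: the returned function maps a length-G vector to a scalar
make_trapezoid_mean <- function(x) {
  w <- trapezoid_weights(x)
  function(f) sum(w * f) / sum(w)
}

x <- c(0, 1, 3)              # unevenly spaced grid
f <- x                       # f(x) = x, whose average over [0, 3] is 1.5
mean(f)                      # 4/3: over-weights the densely sampled left end
make_trapezoid_mean(x)(f)    # 1.5
```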

```{r wrapper-mean, eval = requireNamespace("grf", quietly = TRUE)}
predict(fit, newdata = W_new, wrapper = mean)
```

### Restricted positive mean

Suppose we only credit treatment effects that are positive (for
example, in a cost-effectiveness setting). The wrapper averages
`max(f(x), 0)` over the grid:

```{r wrapper-restricted, eval = requireNamespace("grf", quietly = TRUE)}
restricted_pos_mean <- function(f) sum(pmax(f, 0)) / length(f)
predict(fit, newdata = W_new, wrapper = restricted_pos_mean)
```

Because every row of `F_mat` is passed in turn, `f` inside the
wrapper is just a numeric vector of length `G_grid`. `length(f)` is
therefore the grid size, and dividing by it gives a uniform-weighted
average.
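The order of operations in this wrapper matters: clipping each grid value at zero and then averaging is not the same as clipping the average. Since `pmax(f, 0)` is pointwise at least both `f` and `0`, the restricted mean is never smaller than `max(mean(f), 0)`. A small numeric check with an arbitrary toy vector:

```r
restricted_pos_mean <- function(f) sum(pmax(f, 0)) / length(f)

f <- c(-1, 0.5, 2, -0.5)     # toy values, not from the fit
restricted_pos_mean(f)       # 0.625: clip each point, then average
max(mean(f), 0)              # 0.25: average first, then clip
```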

### Endpoint contrast

The difference `f(x_G) - f(x_1)` is a useful summary when the grid
is ordered (e.g. age, dose, or time). For our age grid it is the
gap in CATE between an 80-year-old and a 30-year-old patient at the
new site:

```{r wrapper-endpoint, eval = requireNamespace("grf", quietly = TRUE)}
endpoint_contrast <- function(f) f[length(f)] - f[1]
predict(fit, newdata = W_new, wrapper = endpoint_contrast)
```
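`endpoint_contrast` is tied to the first and last grid points. A closure can generalise it to any two ages by looking up the nearest grid points once. `make_age_contrast()` is a hypothetical helper, shown here on a toy linear CATE; with `from = 30` and `to = 80` it reproduces `endpoint_contrast` on this grid.

```r
make_age_contrast <- function(grid_age, from, to) {
  i_from <- which.min(abs(grid_age - from))
  i_to   <- which.min(abs(grid_age - to))
  function(f) f[i_to] - f[i_from]
}

age_grid <- seq(30, 80, length.out = 30)
tau <- 0.02 * (age_grid - 50)            # toy linear CATE
endpoint_contrast <- function(f) f[length(f)] - f[1]

endpoint_contrast(tau)                   # 1
make_age_contrast(age_grid, 30, 80)(tau) # same as endpoint_contrast(tau)
make_age_contrast(age_grid, 40, 70)(tau) # uses the nearest grid points to 40 and 70
```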

## Conformal coverage with a wrapper

When you pass `wrapper` to `split_conformal()` (or
`cross_conformal()`, or `conformal_from_fit()`), conformity scores
are computed on the wrapped scalars, so a **single** quantile defines
the interval instead of one quantile per grid point. The interval
covers the wrapped scalar at the nominal level, not the underlying
function pointwise.
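The difference between the two modes can be sketched with toy calibration data. This assumes absolute-residual scores, the score for the scalar case being `|wrapper(observed) - wrapper(predicted)|`, and the standard finite-sample quantile rank `ceiling((n + 1) * (1 - alpha))`; MetaHunt's exact scoring rule may differ.

```r
set.seed(2)
n_cal <- 20; G <- 30; alpha <- 0.1
pred_cal  <- matrix(rnorm(n_cal * G), n_cal, G)            # toy predicted functions
truth_cal <- pred_cal + matrix(rnorm(n_cal * G, sd = 0.3), n_cal, G)
k <- ceiling((n_cal + 1) * (1 - alpha) - 1e-9)             # rank of the score used

# Pointwise: one quantile per grid column (length G)
q_pointwise <- apply(abs(truth_cal - pred_cal), 2, function(s) sort(s)[k])

# Scalar: wrap each row first, then take a single quantile
wrapper  <- mean
s_scalar <- abs(apply(truth_cal, 1, wrapper) - apply(pred_cal, 1, wrapper))
q_scalar <- sort(s_scalar)[k]
# A new site's interval would then be wrapper(prediction) +/- q_scalar
```

Averaging over the grid shrinks the noise in the scores here, so `q_scalar` ends up much smaller than a typical pointwise quantile in this toy setup.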

With only `m = 8` sites, we hold out a single site (the 8th) and
use the other seven for training plus calibration. The calibration
set is small (about half of the seven sites with `cal_frac = 0.5`),
so we use `alpha = 0.1` rather than `0.05`.
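Why does calibration-set size constrain `alpha`? Under the standard finite-sample split-conformal rule (assumed here; we have not confirmed it is exactly what MetaHunt uses), the interval half-width is the `ceiling((n + 1) * (1 - alpha))`-th smallest calibration score, and that rank only exists when it does not exceed `n`. The hypothetical helpers below find the smallest calibration set that gives a finite interval. With roughly three or four calibration sites, it is worth checking whether `res$lower` and `res$upper` come back finite.

```r
# Rank of the calibration score used as the conformal quantile
conformal_rank <- function(n_cal, alpha) {
  # small tolerance guards against floating-point round-up in (n + 1)(1 - alpha)
  ceiling((n_cal + 1) * (1 - alpha) - 1e-9)
}

# Smallest calibration set whose required rank fits inside the set
min_cal_size <- function(alpha) {
  n <- 1
  while (conformal_rank(n, alpha) > n) n <- n + 1
  n
}

min_cal_size(0.10)        # 9
min_cal_size(0.05)        # 19
conformal_rank(4, 0.10)   # 5, more than 4 calibration scores available
```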

```{r ws-split-scalar, eval = requireNamespace("grf", quietly = TRUE)}
# Use 7 sites for training+calibration, predict for the held-out 8th
tr_cal <- 1:7; new <- 8
res <- split_conformal(
  F_hat[tr_cal, , drop = FALSE],
  W[tr_cal, , drop = FALSE],
  W[new, , drop = FALSE],
  K = 3, wrapper = mean, alpha = 0.1, cal_frac = 0.5, seed = 1,
  dfspa_args = list(denoise = FALSE)
)
data.frame(prediction = res$prediction,
           lower      = res$lower,
           upper      = res$upper)
```

With only 8 sites, an empirical-coverage check on a single held-out
site is not informative: that site either falls inside its interval
or it does not, which says little about the coverage rate. For
coverage diagnostics, use a leave-one-out loop or simulate a larger
number of studies. See `?coverage` for the helper function and the
`conformal-prediction` vignette for split-conformal at scale.
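To see what "simulate a larger study count" buys you, here is a package-free sketch: each toy "site" has a scalar prediction and an exchangeable noisy truth, and repeated scalar split conformal with 50 calibration sites covers the held-out truth close to the nominal 90% rate. `split_conformal_scalar()` is a hypothetical helper with absolute-residual scores, not MetaHunt's `split_conformal()`.

```r
split_conformal_scalar <- function(pred_cal, y_cal, pred_new, alpha) {
  n <- length(y_cal)
  k <- ceiling((n + 1) * (1 - alpha) - 1e-9)
  q <- if (k <= n) sort(abs(y_cal - pred_cal))[k] else Inf  # Inf: too few sites
  c(lower = pred_new - q, upper = pred_new + q)
}

set.seed(3)
n_rep <- 2000; n_cal <- 50; alpha <- 0.1
covered <- logical(n_rep)
for (r in seq_len(n_rep)) {
  pred <- rnorm(n_cal + 1)                     # toy scalar predictions
  y    <- pred + rnorm(n_cal + 1, sd = 0.5)    # exchangeable toy truths
  ci   <- split_conformal_scalar(pred[1:n_cal], y[1:n_cal], pred[n_cal + 1], alpha)
  covered[r] <- ci["lower"] <= y[n_cal + 1] && y[n_cal + 1] <= ci["upper"]
}
mean(covered)   # empirical coverage, close to 0.9
```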

## Pointwise vs scalar — quick reference

| Aspect              | Pointwise (`wrapper = NULL`)                       | Scalar (`wrapper` supplied)                         |
|---------------------|----------------------------------------------------|-----------------------------------------------------|
| Output shape        | `nrow(W_new)` x `G_grid` matrix                    | length-`nrow(W_new)` numeric vector                 |
| Conformal quantile  | one per grid point (length-`G_grid`)               | a single scalar                                     |
| Coverage guarantee  | per grid point, marginally (not joint over grid)   | for the scalar summary, marginally                  |
| Best for            | visualising the predicted function with a band     | reporting a single number with a valid CI           |
| Example call        | `split_conformal(F, W, W_new, K = 3)`              | `split_conformal(F, W, W_new, K = 3, wrapper = mean)` |

A pointwise band is a visualisation aid; a scalar interval is the
right object for an inferential claim about a specific functional.
Pick the wrapper that matches the question you actually want to
answer, and let the conformal machinery do the rest.

## See also

- `vignette("data-prep")` — building `F_hat` from per-site fitted
  models (including the `grf::causal_forest` dispatch and the
  `predict_fn` escape hatch used here).
- `vignette("conformal-prediction")` — split- and cross-conformal
  routines at scale, including empirical-coverage diagnostics that
  need more than a handful of held-out sites.
