Advanced brms custom families: occupancy models and the flocker_data format

Jacob Socolar

2024-02-03

brms offers the option to specify models that incorporate Stan code for custom likelihood functions. brms can then fit models using these likelihoods where the distributional parameters of the custom likelihood are given linear predictors in the usual way. Relatively straightforward examples are given in the brms custom families vignette, but brms provides surprising flexibility to do fancy things with custom families. In this vignette, I show how we use this flexibility to harness brms to occupancy modeling likelihoods. I assume that the reader knows a little bit about occupancy models and brms, but not much else.

The problem

The challenge in shoehorning an occupancy model into a brms custom family is that each multiplicative term in the occupancy-model likelihood represents a closure-unit that might contain multiple repeat visits with different detection covariates. The likelihood does not factorize at the visit level and therefore must be computed unit-wise, but the linear predictor for detection needs to be computed visit-wise. How can we tell brms to compute a visit-wise detection predictor without telling it to try to compute a visit-wise log-likelihood?
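To make the non-factorization concrete, the standard single-season occupancy likelihood can be written as below. The symbols are introduced here for illustration and are not used elsewhere in this vignette: \(N\) is the number of closure-units, \(J_i\) the number of visits to unit \(i\), \(y_{ij}\) the detection outcome on visit \(j\), \(\psi_i\) the occupancy probability, \(p_{ij}\) the visit-specific detection probability, and \(Q_i\) an indicator that unit \(i\) has at least one detection.

```latex
\[
\mathcal{L} = \prod_{i=1}^{N} \mathcal{L}_i, \qquad
\mathcal{L}_i =
\begin{cases}
\psi_i \prod_{j=1}^{J_i} p_{ij}^{y_{ij}} (1 - p_{ij})^{1 - y_{ij}}, & Q_i = 1, \\[4pt]
\psi_i \prod_{j=1}^{J_i} (1 - p_{ij}) + (1 - \psi_i), & Q_i = 0.
\end{cases}
\]
```

The sum in the \(Q_i = 0\) case is what prevents the unit-level term from splitting into independent visit-level factors, even though each \(p_{ij}\) depends on its own visit-level covariates.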

Two key tricks: unlooped families and vint terms

The first trick involves the unassuming loop argument to brms::custom_family(). Defaulting to TRUE, loop controls whether or not the custom likelihood will be evaluated row-wise. The brms::custom_family documentation points out that setting loop = FALSE can enable efficient vectorized computation of the likelihood. We are going to co-opt this argument for a different purpose: we set loop = FALSE not to enable vectorization, but so that the custom likelihood has access to all rows of the data at once and can therefore “see” all of the relevant visits as it computes the unit-wise likelihood.

Let’s think this through: our likelihood function is going to ingest a one-dimensional array y of visit-wise integer response data, and then vectors of pre-computed linear predictors for two distributional parameters: occ for occupancy and det for detection (the detection predictor appears under the name mu in the generated Stan code shown later in this vignette). If we have \(M\) total visits to each site, then \(\frac{M-1}{M}\) of the elements of occ will be redundant (since the occupancy predictor cannot change across visits), but there will be no redundancy in y or det.

What we need now is a likelihood function that can associate each row with the correct closure-unit. Here’s where the second trick comes in. Some likelihoods require “extra” response data that inform the likelihood without being involved in the computation of the linear predictors. The canonical example is the number of trials in a binomial response. To supply such data in custom likelihoods, brms provides the functions vint() and vreal() (for integer and real data respectively). We are going to use repeated calls to vint() to inject all of the necessary indexing information into the likelihood.
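As a rough illustration of how the two tricks combine, the sketch below declares an unlooped custom family and a formula that passes indexing data through vint(). It is modeled loosely on the ideas in this vignette and is not flocker’s actual family definition: the data column names (n_unit, n_rep, Q, rep_index1, rep_index2, det_cov, occ_cov) are hypothetical, and the count of five vint terms assumes a toy dataset with at most two visits per unit.

```r
# Sketch only: an unlooped custom family with two distributional parameters.
# Identity links are used on the assumption that the Stan likelihood applies
# the logit transformation itself.
occ_family <- brms::custom_family(
  "occupancy_single",
  dpars = c("mu", "occ"),            # detection (main parameter) and occupancy
  links = c("identity", "identity"),
  type = "int",                      # integer-valued response
  vars = paste0("vint", 1:5),        # extra integer data passed to the likelihood
  loop = FALSE                       # evaluate the likelihood once over all rows
)

# Hypothetical column names; each vint() argument becomes vint1, vint2, ...
# in the order supplied.
occ_formula <- brms::bf(
  y | vint(n_unit, n_rep, Q, rep_index1, rep_index2) ~ 1 + det_cov,
  occ ~ 1 + occ_cov
)
```

With loop = FALSE, the entries of vars are written without a row subscript, because the Stan function receives each vint term as a whole array rather than one element at a time.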

The flocker_data format for a single-season model

Suppose the model has \(N\) unique closure-units, and the maximum number of visits to any closure-unit is \(M\). We will ensure that the data are formatted such that the first \(N\) rows correspond to the first visits to each closure-unit. Then we will pass \(M\) vint terms whose first \(N\) elements each give the row indices \(i\) corresponding to the \(m\)th visit to that closure-unit, for \(m\) in \(1\) to \(M\). All subsequent elements with indices \(i > N\) are irrelevant. Note that the first of these vint arrays is redundant and contains as its first \(N\) elements simply the integers from 1 to \(N\). We include it anyway to keep the code logic transparent and avoid bugs. Moreover, it will become relevant in more advanced multi-season models where it is possible to have closure-units that receive zero visits but still are relevant to the likelihood (see Even fancier families).
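The row ordering and index terms described above can be sketched in base R. This is an illustrative helper with a hypothetical name, not flocker’s make_flocker_data_static; it assumes an \(N \times M\) detection matrix in which missing visits are trailing NAs, and it uses -99 as an arbitrary filler for index entries that the likelihood never reads.

```r
# Hypothetical helper: pack an N x M detection matrix into long format,
# with all first visits in the first N rows, plus M index columns.
pack_visits <- function(obs) {
  N <- nrow(obs)
  M <- ncol(obs)
  n_rep <- rowSums(!is.na(obs))
  stopifnot(all(n_rep >= 1))
  # Assumption: visits are left-aligned, so NAs only trail each row.
  for (i in seq_len(N)) stopifnot(!anyNA(obs[i, seq_len(n_rep[i])]))

  # Stack all first visits, then all second visits, and so on.
  unit <- integer(0)
  visit <- integer(0)
  for (m in seq_len(M)) {
    visited <- which(!is.na(obs[, m]))
    unit <- c(unit, visited)
    visit <- c(visit, rep(m, length(visited)))
  }
  n_row <- length(unit)
  y <- obs[cbind(unit, visit)]

  # rep_index[i, m] is the row holding the m-th visit to unit i (i <= N).
  # Rows beyond N keep the filler value and are irrelevant.
  rep_index <- matrix(-99L, nrow = n_row, ncol = M)
  for (r in seq_len(n_row)) rep_index[unit[r], visit[r]] <- r
  colnames(rep_index) <- paste0("rep_index", seq_len(M))

  data.frame(y = y, unit = unit, visit = visit, rep_index)
}

# Three units with 2, 3, and 1 visits respectively.
obs <- matrix(c(1, 0, NA,
                0, 0, 0,
                0, NA, NA), nrow = 3, byrow = TRUE)
packed <- pack_visits(obs)
print(packed)
```

In the printed result, the first three entries of rep_index1 are simply 1, 2, 3, which is the redundancy noted above for the first index term.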

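To simplify the Stan code that decodes this data structure, we also pass three additional vint terms: the number of closure-units \(N\) (only the first element is used), the number of visits to each closure-unit (first \(N\) elements), and a binary indicator \(Q\) for whether each closure-unit has at least one detection (first \(N\) elements).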

Thus, the likelihood function has a number of vint terms equal to three plus the maximum number of repeat visits to any site. The Stan code to decode this format depends on the number of repeat visits and is generated on-the-fly at runtime. Here’s how it looks for a dataset with a maximum of four repeat visits:

cat(flocker:::make_occupancy_single_lpmf(4))
##   real occupancy_single_lpmf(
##     array[] int y, // detection data
##     vector mu, // lin pred for detection
##     vector occ, // lin pred for occupancy. Elements after vint1[1] irrelevant.
##     array[] int vint1, // # units (n_unit). Elements after 1 irrelevant.
##     array[] int vint2, // # sampling events per unit (n_rep). Elements after vint1[1] irrelevant.
##     array[] int vint3, // Indicator for > 0 detections (Q). Elements after vint1[1] irrelevant.
##   
##   // indices for jth repeated sampling event to each unit (elements after vint1[1] irrelevant):
##     array[] int vint4,
##     array[] int vint5,
##     array[] int vint6,
##     array[] int vint7
## ) {
##   // Create array of the rep indices that correspond to each unit.
##     array[vint1[1], 4] int index_array;
##       index_array[,1] = vint4[1:vint1[1]];
##       index_array[,2] = vint5[1:vint1[1]];
##       index_array[,3] = vint6[1:vint1[1]];
##       index_array[,4] = vint7[1:vint1[1]];
## 
##   // Initialize and compute log-likelihood
##     real lp = 0;
##     for (i in 1:vint1[1]) {
##       array[vint2[i]] int indices = index_array[i, 1:vint2[i]];
##       if (vint3[i] == 1) {
##         lp += bernoulli_logit_lpmf(1 | occ[i]);
##         lp += bernoulli_logit_lpmf(y[indices] | mu[indices]);
##       }
##       if (vint3[i] == 0) {
##         lp += log_sum_exp(bernoulli_logit_lpmf(1 | occ[i]) + 
##                               sum(log1m_inv_logit(mu[indices])), bernoulli_logit_lpmf(0 | occ[i]));
##       }
##     }
##     return(lp);
##   }
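For readers more comfortable in R than Stan, the decoding logic above can be mirrored in base R. This is a sketch for checking understanding, not code from flocker: the function names are hypothetical, the toy inputs assume three units with 2, 3, and 1 visits, and -99 is an arbitrary filler for unused index entries.

```r
# Base-R analogues of the Stan functions used in the generated code.
log1m_inv_logit <- function(x) plogis(x, lower.tail = FALSE, log.p = TRUE)
bernoulli_logit_lpmf <- function(y, eta) {
  sum(ifelse(y == 1, plogis(eta, log.p = TRUE), log1m_inv_logit(eta)))
}
log_sum_exp <- function(a, b) max(a, b) + log1p(exp(-abs(a - b)))

# Unit-wise log-likelihood computed from visit-wise inputs.
occupancy_single_ll <- function(y, mu, occ, n_unit, n_rep, Q, index_array) {
  lp <- 0
  for (i in seq_len(n_unit)) {
    # Rows of the data that hold the visits to unit i.
    indices <- index_array[i, seq_len(n_rep[i])]
    if (Q[i] == 1) {
      # At least one detection: the unit must be occupied.
      lp <- lp + bernoulli_logit_lpmf(1, occ[i]) +
        bernoulli_logit_lpmf(y[indices], mu[indices])
    } else {
      # No detections: occupied but never detected, or unoccupied.
      lp <- lp + log_sum_exp(
        bernoulli_logit_lpmf(1, occ[i]) + sum(log1m_inv_logit(mu[indices])),
        bernoulli_logit_lpmf(0, occ[i])
      )
    }
  }
  lp
}

# Toy data: six rows, with the first visits to units 1 to 3 in rows 1 to 3.
y <- c(1, 0, 0, 0, 0, 0)
mu <- rep(0, 6)    # detection probability 0.5 on every visit
occ <- rep(0, 6)   # occupancy probability 0.5; only elements 1 to 3 are read
index_array <- rbind(c(1, 4, -99), c(2, 5, 6), c(3, -99, -99))
ll <- occupancy_single_ll(y, mu, occ, n_unit = 3, n_rep = c(2, 3, 1),
                          Q = c(1, 0, 0), index_array = index_array)

# Hand computation: 0.5 * 0.25, 0.5 * 0.125 + 0.5, and 0.5 * 0.5 + 0.5.
all.equal(ll, log(0.125 * 0.5625 * 0.75))
```

Note that occ is indexed by unit (i) while y and mu are indexed by row (indices), which is exactly why only the first \(N\) elements of occ matter.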

In addition to the functions to generate this custom Stan code, the main workhorses in flocker are functions to pack and unpack data and linear predictors from the shape of the observational data to the flocker_data format (and back). For further details, check out flocker:::make_flocker_data_static (for packing) and flocker:::get_positions (for unpacking).

A note on performance

As noted, the flocker approach to fitting in brms contains one substantial redundancy: the linear predictor for occupancy need be computed only once per closure-unit, but flocker computes it once per visit. In addition, it is not possible to use Stan’s reduce_sum functionality for within-chain parallelization of the computation of the linear predictors, since chunking up the data destroys the validity of the indexing (ensuring that no closure-unit ends up split across multiple chunks would require a level of control that reduce_sum does not provide). Despite these disadvantages, we find that occupancy modeling with flocker is remarkably performant, in many cases outperforming our previous hand-coded Stan implementations of models for large datasets and comparing favorably to other contemporary packages for occupancy modeling.

Even fancier families

Flocker provides a set of families that are more involved still. The first are the multi-season families, which group closure-units into series within which the occupancy state changes via modeled colonization and extinction dynamics. The second are data-augmented multi-species models, in which closure-units are grouped within species whose presence in the community (and thus availability for detection) is modeled explicitly.

The multi-season format

We fit multi-season models via a hidden Markov model (HMM) approach to the likelihood. This vignette does not cover the implementation of that likelihood in detail, only the data that we need to send to the unlooped likelihood function. Suppose the data contain \(S\) distinct series (i.e. distinct hidden Markov sequences) and \(U\) closure-units (i.e. the sum over series of the number of timesteps per series). The data are ordered so that the first \(S\) rows correspond to the first repeat visit to the first timestep of all series (or to a ghost row if a given series has no visits in the first timestep), and the first \(U\) rows correspond to the first repeat visit to each closure-unit (i.e. timestep, or a ghost row if a given timestep contains no visits).

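We pass the following vint terms: the number of series \(S\) (only the first element is used); the number of closure-units \(U\) (only the first element is used); the number of timesteps in each series (first \(S\) elements); the number of visits to each closure-unit (first \(U\) elements); a binary indicator \(Q\) for whether each closure-unit has at least one detection (first \(U\) elements); one term per timestep, up to the maximum number of timesteps in any series, giving for each series the row index of the first visit (or ghost row) for that timestep (first \(S\) elements); and one term per repeat visit, up to the maximum number of visits in any timestep, giving for each closure-unit the row index of that visit (first \(U\) elements).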

Thus, we pass a number of vint terms equal to five plus the maximum number of timesteps in any series plus the maximum number of visits in any timestep. Here’s Stan code to decode this format and compute the likelihood for 5 timesteps with a maximum of 4 repeat visits, in this case for the colonization-extinction flavor of multi-season model. Note that this likelihood includes custom functions that flocker defines elsewhere and passes to the custom family via stanvars:

cat(flocker:::make_occupancy_multi_colex_lpmf(4, 5))
##   real occupancy_multi_colex_lpmf(
##     array[] int y, // detection data
##     vector mu, // linear predictor for detection
##     vector occ, // linear predictor for initial occupancy. Elements after vint1[1] irrelevant.
##     vector colo, // linear predictor for colonization. Elements after vint2[1] irrelevant.
##     vector ex, // linear predictor for extinction. Elements after vint2[1] irrelevant.
##     array[] int vint1, // # of series (# of HMMs). Elements after 1 irrelevant.
##     array[] int vint2, // # units (series-years). Elements after 1 irrelevant.
##     array[] int vint3, // # years per series. Elements after vint1[1] irrelevant.
##     array[] int vint4, // # sampling events per unit (n_rep). Elements after vint2[1] irrelevant.
##     array[] int vint5, // Indicator for > 0 detections (Q). Elements after vint2[1] irrelevant.
##   
##   // indices for jth unit (first rep) for each series. Elements after vint1[1] irrelevant.
##     array[] int vint6,
##     array[] int vint7,
##     array[] int vint8,
##     array[] int vint9,
##     array[] int vint10,
## 
## // indices for jth repeated sampling event to each unit (elements after vint2[1] irrelevant):
##     array[] int vint11,
##     array[] int vint12,
##     array[] int vint13,
##     array[] int vint14
## ) {
##   // Create array of the unit indices that correspond to each series.
##     array[vint1[1], 5] int unit_index_array;
##       unit_index_array[,1] = vint6[1:vint1[1]];
##       unit_index_array[,2] = vint7[1:vint1[1]];
##       unit_index_array[,3] = vint8[1:vint1[1]];
##       unit_index_array[,4] = vint9[1:vint1[1]];
##       unit_index_array[,5] = vint10[1:vint1[1]];
## 
## 
##   // Create array of the rep indices that correspond to each unit.
##     array[vint2[1], 4] int visit_index_array;
##       visit_index_array[,1] = vint11[1:vint2[1]];
##       visit_index_array[,2] = vint12[1:vint2[1]];
##       visit_index_array[,3] = vint13[1:vint2[1]];
##       visit_index_array[,4] = vint14[1:vint2[1]];
## 
##   // Initialize and compute log-likelihood
##     real lp = 0;
##     for (i in 1:vint1[1]) {
##       int n_year = vint3[i];
##       array[n_year] int Q = vint5[unit_index_array[i,1:n_year]];
##       array[n_year] int n_obs = vint4[unit_index_array[i,1:n_year]];
##       int max_obs = max(n_obs);
##       array[n_year, max_obs] int y_i;
##       real occ_i = occ[unit_index_array[i,1]];
##       vector[n_year] colo_i = to_vector(colo[unit_index_array[i,1:n_year]]);
##       vector[n_year] ex_i = to_vector(ex[unit_index_array[i,1:n_year]]);
##       array[n_year] row_vector[max_obs] det_i;
##       
##       for (j in 1:n_year) {
##         if (n_obs[j] > 0) {
##           y_i[j, 1:n_obs[j]] = y[visit_index_array[unit_index_array[i, j], 1:n_obs[j]]];
##           det_i[j, 1:n_obs[j]] = to_row_vector(mu[visit_index_array[unit_index_array[i, j], 1:n_obs[j]]]);
##         }
##       }
##       lp += forward_colex(n_year, Q, n_obs, y_i, occ_i, colo_i, ex_i, det_i);
##     }
##     return(lp);
##   }
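The two-level lookup in the Stan code (series to units, then units to visit rows) can be traced on a toy example in base R. This sketch is illustrative only: the function name is hypothetical, the toy data assume two series of two timesteps each with at most two visits, and -99 is an arbitrary filler both for the ghost row’s response and for unused index entries.

```r
# Hypothetical helper: recover the timestep-by-visit detection matrix for
# series i, mirroring how y_i is filled in the Stan code above.
decode_series <- function(i, y, n_year, n_rep, unit_index, visit_index) {
  units <- unit_index[i, seq_len(n_year[i])]   # rows of the first visits
  n_obs <- n_rep[units]                        # visits per timestep
  y_i <- matrix(NA_integer_, nrow = n_year[i], ncol = max(n_obs))
  for (j in seq_len(n_year[i])) {
    if (n_obs[j] > 0) {
      rows <- visit_index[units[j], seq_len(n_obs[j])]
      y_i[j, seq_len(n_obs[j])] <- y[rows]
    }
  }
  y_i
}

# Rows 1-2: first timestep of series 1 and 2. Rows 3-4: second timestep of
# series 1 (a ghost row with no visits) and of series 2. Rows 5-6: second visits.
y <- c(1L, 0L, -99L, 0L, 0L, 1L)
n_year <- c(2, 2)                 # timesteps per series
n_rep <- c(2, 1, 0, 2)            # visits per closure-unit
unit_index <- rbind(c(1, 3), c(2, 4))
visit_index <- rbind(c(1, 5), c(2, -99), c(-99, -99), c(4, 6))

y_1 <- decode_series(1, y, n_year, n_rep, unit_index, visit_index)
y_2 <- decode_series(2, y, n_year, n_rep, unit_index, visit_index)
print(y_1)
print(y_2)
```

The ghost row for the unvisited timestep is never read as data: its visit count is zero, so its row of y_1 stays empty (NA here, uninitialized in Stan).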

The data augmented format

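For the data augmented model, suppose that the dataset contains \(I\) sites, up to \(J\) visits per site, and \(K\) species (including the data-augmented pseudospecies). The data are ordered so that the first \(I \times K\) rows each represent the first visit to each closure-unit (species \(\times\) site). Then we pass auxiliary vint terms including: the number of closure-units (only the first element is used); the number of visits to each closure-unit; a binary indicator \(Q\) for whether each closure-unit has at least one detection; the total number of species \(K\), observed plus augmented (only the first element is used); a binary indicator for whether each species was observed (first \(K\) elements); the species to which each closure-unit belongs; and \(J\) terms giving, for each closure-unit, the row index of the \(j\)th visit.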

Thus, we pass a number of vint terms equal to six plus the maximum number of visits at any site. Here’s Stan code to decode this format and compute the likelihood for a dataset with a maximum of 4 repeat visits:

cat(flocker:::make_occupancy_augmented_lpmf(4))
##   real occupancy_augmented_lpmf(
##     array[] int y, // detection data
##     vector mu, // lin pred for detection
##     vector occ, // lin pred for occupancy. Elements after vint1[1] irrelevant.
##     vector Omega, // lin pred for availability.  Elements after 1 irrelevant.
##     array[] int vint1, // # units (n_unit). Elements after 1 irrelevant.
##     array[] int vint2, // # sampling events per unit (n_rep). Elements after vint1[1] irrelevant.
##     array[] int vint3, // Indicator for > 0 detections (Q). Elements after vint1[1] irrelevant.
##     
##     array[] int vint4, // # species (observed + augmented). Elements after 1 irrelevant.
##     array[] int vint5, // Indicator for species was observed.  Elements after vint4[1] irrelevant
##     
##     array[] int vint6, // species
##   
##   // indices for jth repeated sampling event to each unit (elements after vint1[1] irrelevant):
##     array[] int vint7,
##     array[] int vint8,
##     array[] int vint9,
##     array[] int vint10
## ) {
##   // Create array of the rep indices that correspond to each unit.
##     array[vint1[1], 4] int index_array;
##       index_array[,1] = vint7[1:vint1[1]];
##       index_array[,2] = vint8[1:vint1[1]];
##       index_array[,3] = vint9[1:vint1[1]];
##       index_array[,4] = vint10[1:vint1[1]];
## 
##   // Initialize and compute log-likelihood
##     real lp = 0;
##     
##     for (sp in 1:vint4[1]) {
##       real lp_s = 0;
##       if (vint5[sp] == 1) {
##         for (i in 1:vint1[1]) {
##           if (vint6[i] == sp) {
##             array[vint2[i]] int indices = index_array[i, 1:vint2[i]];
##             if (vint3[i] == 1) {
##               lp_s += bernoulli_logit_lpmf(1 | occ[i]);
##               lp_s += bernoulli_logit_lpmf(y[indices] | mu[indices]);
##             }
##             if (vint3[i] == 0) {
##               lp_s += log_sum_exp(bernoulli_logit_lpmf(1 | occ[i]) + 
##                                     sum(log1m_inv_logit(mu[indices])), bernoulli_logit_lpmf(0 | occ[i]));
##             }
##           }
##         }
##         lp += log_inv_logit(Omega[1]) + lp_s;
##       } else {
##         for (i in 1:vint1[1]) {
##           if (vint6[i] == sp) {
##             array[vint2[i]] int indices = index_array[i, 1:vint2[i]];
##             lp_s += log_sum_exp(bernoulli_logit_lpmf(1 | occ[i]) + 
##                                   sum(log1m_inv_logit(mu[indices])), bernoulli_logit_lpmf(0 | occ[i]));
##           }
##         }
##         lp += log_sum_exp(log1m_inv_logit(Omega[1]), log_inv_logit(Omega[1]) + lp_s);  
##       }
##     }
##     return(lp);
##   }
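Read off the Stan code above, the species-level likelihood can be summarized as follows. The symbols are introduced here for illustration: \(\omega\) is the inverse logit of Omega[1], \(\mathcal{L}_{ik}\) is the closure-unit likelihood for site \(i\) and species \(k\) (of the same form as in the single-season model), and \(O_k\) indicates whether species \(k\) was observed at least once.

```latex
\[
\mathcal{L} = \prod_{k=1}^{K} \mathcal{L}_k, \qquad
\mathcal{L}_k =
\begin{cases}
\omega \prod_{i=1}^{I} \mathcal{L}_{ik}, & O_k = 1, \\[4pt]
(1 - \omega) + \omega \prod_{i=1}^{I} \mathcal{L}_{ik}, & O_k = 0.
\end{cases}
\]
```

For a never-observed species every closure-unit has zero detections, so each \(\mathcal{L}_{ik}\) in the second case takes the no-detection form, and the additional \((1 - \omega)\) term accounts for the possibility that the species is absent from the community altogether.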