// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
#include <Rcpp.h>
#include <math.h>
#include <cmath>
#include <limits>
#include <stdexcept>
using namespace Rcpp;

// Eigen-decomposition of a symmetric matrix, delegated to R's base::eigen().
//
// Returns the list produced by eigen(x, symmetric = TRUE): `values` (sorted
// in decreasing order, per R's documentation) and `vectors` (eigenvectors in
// the columns). With symmetric = TRUE, R only reads the lower triangle of A.
List my_eigen(arma::mat A){
  Environment pkg = Environment::namespace_env("base") ;
  Function f = pkg["eigen"] ;
  // EISPACK is deliberately not passed: current R documents it as defunct
  // and ignored, and FALSE was its default anyway. Omitting it keeps the
  // call valid if the argument is ever removed from eigen()'s signature.
  SEXP Eigen = f(Named("x") = A,
                 _["symmetric"] = true,
                 _["only.values"] = false) ;
  return as<List>(Eigen) ;
}

// Univariate Gauss-Hermite rule rescaled for expectations under N(0, 1).
//
// Fetches nodes `x` and weights `w` from fastGHQuad::gaussHermiteData(n).
// The sqrt(2) / sqrt(pi) rescaling below presumes those are for the weight
// function exp(-x^2), so that E[g(Z)] ~= sum_i w_i/sqrt(pi) * g(sqrt(2) x_i).
// Returns a List with `nodes` (points x 1 matrix) and `weights`.
List fastGauss_Hermite(int points){
  Environment ghq_namespace = Environment::namespace_env("fastGHQuad") ;
  Function gauss_hermite_data = ghq_namespace["gaussHermiteData"] ;
  List rule = gauss_hermite_data(Named("n") = points) ;

  arma::vec raw_nodes = rule["x"] ;
  arma::vec raw_weights = rule["w"] ;

  // Change of variable z = sqrt(2) * x maps the exp(-x^2) rule onto N(0, 1).
  arma::mat scaled_nodes = raw_nodes * sqrt(2.) ;

  return List::create(_["nodes"] = scaled_nodes,
                      _["weights"] = raw_weights / sqrt(arma::datum::pi)) ;
}

// Integer power n^q by repeated multiplication; q <= 0 yields 1.
//
// The result is used as a grid size (points^dimensions), so an overflowed
// value must never escape. Signed overflow is undefined behaviour, and
// `long` is only 32 bits on some platforms (e.g. 64-bit Windows). Every
// step is therefore checked before multiplying, and std::overflow_error is
// thrown if n^q does not fit in a long.
long power(int n, int q) {
  const long base = n ;
  const long lmax = std::numeric_limits<long>::max() ;
  const long lmin = std::numeric_limits<long>::min() ;
  long result = 1 ;
  for (int i = 0; i < q; ++i) {
    // Pre-multiplication overflow test for result * base (SEI CERT INT32-C).
    bool overflow ;
    if (result > 0) {
      overflow = (base > 0) ? (result > lmax / base) : (base < lmin / result) ;
    } else {
      overflow = (base > 0) ? (result < lmin / base)
                            : (result != 0 && base < lmax / result) ;
    }
    if (overflow) {
      throw std::overflow_error("power(): result does not fit in a long") ;
    }
    result *= base ;
  }
  return result ;
}

// Zero-based equivalent of R's expand.grid() for `dm` copies of 0..n-1,
// needed for multivariate (tensor-product) Gauss-Hermite quadrature.
//
// Returns an n^dm x dm matrix. Column 0 varies fastest. In column i, each
// index is repeated n^i times in a row, and that block pattern is tiled
// n^(dm-1-i) times.
//
// Stops if n < 1 or dm < 1. Previously n == 0 caused an integer division by
// zero (crashing the R session), and n - 1 < 0 wrapped to a huge unsigned
// upper bound in regspace().
arma::umat expand_grid(int n, int dm) {
  if (n < 1 || dm < 1) {
    stop("expand_grid: n and dm must both be at least 1.") ;
  }
  const long total = power(n, dm) ;
  arma::umat idx(total, dm) ;
  const arma::uvec indx = arma::regspace<arma::uvec>(0, n - 1) ;

  long run = 1 ;            // n^i: how often each index repeats in column i
  long reps = total / n ;   // n^(dm-1-i): how often the block pattern is tiled
  for (int i = 0; i < dm; ++i) {
    idx.col(i) = arma::repmat(arma::repelem(indx, run, 1), reps, 1) ;
    run *= n ;
    reps /= n ;
  }
  return idx ;
}

// Multivariate Gauss-Hermite quadrature for expectations under N(mu, Sigma).
//
// Builds the tensor-product grid of n univariate Gauss-Hermite nodes in each
// of length(mu) dimensions (n^length(mu) points, so memory grows
// exponentially with the dimension). It can optionally prune low-weight
// points. It then maps the grid onto N(mu, Sigma) through the
// eigen-decomposition Sigma = V diag(lambda) V':
//   node -> mu + V diag(sqrt(lambda)) node
//
//   n     - number of points in each dimension before pruning (>= 1)
//   mu    - mean vector (length >= 1)
//   Sigma - length(mu) x length(mu) covariance matrix; must be symmetric
//           positive semi-definite
//   prune - 0 (default) for no pruning; otherwise a probability in (0, 1].
//           Only points whose weight is strictly above the `prune` quantile
//           of all weights are kept.
//
// Returns a List with `nodes` (one point per row, length(mu) columns) and
// `weights` (one weight per point).
// [[Rcpp::export]]
List mgauss_hermite(int n, arma::vec mu, arma::mat Sigma, double prune = 0.){
  if(!Sigma.is_square()){
    stop("Sigma is not a square matrix.") ;
  }
  if(Sigma.n_rows != mu.n_elem){
    stop("mu and Sigma have nonconformable dimensions") ;
  }
  // The grid construction needs at least one dimension and one point per
  // dimension; fail with a clear message instead of an internal error.
  if(mu.n_elem == 0){
    stop("mu must have at least one element.") ;
  }
  if(n < 1){
    stop("n must be at least 1.") ;
  }

  List gh = fastGauss_Hermite(n) ;
  arma::vec weights = gh["weights"] ;
  arma::mat nodes = gh["nodes"] ;

  const int dm = mu.n_elem ;
  const long nr = power(n, dm) ;

  // Node indices of the full grid, stacked dimension by dimension (column
  // major). This is the part that grows exponentially in dm.
  const arma::uvec idx = arma::vectorise(expand_grid(n, dm)) ;
  const arma::uvec first_col = {0} ;
  arma::mat pts(nodes.submat(idx, first_col)) ;
  pts.reshape(nr, dm) ;                    // row = grid point, column = dimension

  arma::mat wtsM = weights(idx) ;
  wtsM.reshape(nr, dm) ;
  arma::vec wts = arma::prod(wtsM, 1) ;    // product rule: multiply per-dimension weights

  if(prune > 0.) {
    const arma::vec p = {prune} ;
    const arma::vec qwt = arma::quantile(wts, p) ;
    const arma::uvec keep = arma::find(wts > qwt(0)) ;
    pts = pts.rows(keep) ;
    wts = wts(keep) ;
  }

  // rotate, scale, translate points
  List spd = my_eigen(Sigma) ;
  arma::vec eigval = spd["values"] ;
  arma::mat eigvec = spd["vectors"] ;

  // Rounding can make some eigenvalues of a (singular) PSD Sigma slightly
  // negative, and sqrt() would then turn those nodes into NaN. Treat
  // eigenvalues within a relative tolerance of zero as zero. Reject clearly
  // negative ones, because the nodes would be meaningless.
  const double tol = std::sqrt(std::numeric_limits<double>::epsilon()) *
                     arma::max(arma::abs(eigval)) ;
  if(arma::any(eigval < -tol)){
    stop("Sigma is not positive semi-definite.") ;
  }
  eigval = arma::clamp(eigval, 0., arma::datum::inf) ;

  const arma::mat rot = eigvec * arma::diagmat(arma::sqrt(eigval)) ;
  pts = rot * pts.t() ;                    // dm x (number of points)
  pts.each_col() += mu ;

  return(List::create(_["nodes"] = pts.t(),
                      _["weights"] = arma::vectorise(wts))) ;
}

//
// Compute the k-th order raw moment E(X^k) of a univariate truncated normal
//   distribution: X ~ N(mu, sd^2) restricted to a < X < b.
//
// Uses the recursion in John Burkardt's "The Truncated Normal Distribution"
//   (p. 25). With alpha = (a - mu)/sd, beta = (b - mu)/sd and
//   Z = Phi(beta) - Phi(alpha):
//     L_0 = 1
//     L_1 = (phi(alpha) - phi(beta)) / Z
//     L_i = (alpha^(i-1) phi(alpha) - beta^(i-1) phi(beta)) / Z + (i - 1) L_(i-2)
//     E(X^k) = sum_{i=0}^{k} choose(k, i) sd^i mu^(k-i) L_i
//
// @param k order of moment (non-negative; k = 0 returns 1).
// @param mu mean of parent normal distribution.
// @param a lower bound (may be -Inf).
// @param b upper bound (may be +Inf).
// @param sd standard deviation of parent normal distribution.
// @return the k-th raw moment of the truncated distribution.
// Stops with an error if k < 0 or if the interval has zero or negative
//   probability (e.g. a >= b, or a = +Inf, or b = -Inf).
// [[Rcpp::export]]
double univariatetrunmnt(int k, double mu, double a, double b, double sd){
  if(k < 0) stop("k must be non-negative") ;

  // Standardized bounds, Phi (cdf) and phi (pdf) at each bound. For an
  // infinite bound phi = 0, so the finite placeholder +/-10 only keeps
  // alpha^i * phi(alpha) at 0 instead of Inf * 0 = NaN. The sign of the
  // infinity decides whether Phi is 0 or 1.
  double alpha, palpha, dalpha ;
  if(std::isinf(a)){
    alpha  = (a < 0.) ? -10. : 10. ;
    palpha = (a < 0.) ? 0. : 1. ;
    dalpha = 0. ;
  } else {
    alpha  = (a - mu) / sd ;
    palpha = arma::normcdf(alpha) ;
    dalpha = arma::normpdf(alpha) ;
  }
  double beta, pbeta, dbeta ;
  if(std::isinf(b)){
    beta  = (b < 0.) ? -10. : 10. ;
    pbeta = (b < 0.) ? 0. : 1. ;
    dbeta = 0. ;
  } else {
    beta  = (b - mu) / sd ;
    pbeta = arma::normcdf(beta) ;
    dbeta = arma::normpdf(beta) ;
  }

  const double prob = pbeta - palpha ;

  // `<=` (not `==`) also rejects reversed intervals, whose "probability" is
  // negative and would silently produce a meaningless moment.
  if(prob <= 0.) stop("The interval has zero probability") ;
  if(k == 0) return(1.) ;

  // Horner-style accumulation: after processing order i,
  //   moment = sum_{m=0}^{i} choose(k, m) sd^m mu^(i-m) L_m,
  // so at i = k it equals E(X^k).
  double sdi = sd ;            // sd^i
  double betapdf  = dbeta ;    // beta^(i-1)  * phi(beta)
  double alphapdf = dalpha ;   // alpha^(i-1) * phi(alpha)
  double L0 = 1. ;             // L_(i-2)
  double L1 = (alphapdf - betapdf) / prob ;   // L_(i-1); starts as L_1
  double moment = mu + k * sdi * L1 ;         // orders 0 and 1

  for(int i = 2; i <= k ; i++){
    sdi *= sd ;
    betapdf  *= beta ;
    alphapdf *= alpha ;
    const double Li = (alphapdf - betapdf) / prob + (i - 1) * L0 ;
    L0 = L1 ;
    L1 = Li ;
    moment = moment * mu + R::choose(k, i) * sdi * Li ;
  }
  return(moment) ;
}

// Quadrature approximation of the probability that every response lies in
// its interval, integrating over the quadrature nodes:
//   P = sum_q w_q * prod_j [ Phi(upper_j; m_jq, sigmae_j) - Phi(lower_j; m_jq, sigmae_j) ]
//   m_jq = mu_j + Z[j, ] . nodes[q, ]
// where Phi(x; m, s) is the N(m, s^2) cdf. GH must hold `weights` and
// `nodes` (one node per row); presumably this is the list returned by
// mgauss_hermite().
// [[Rcpp::export]]
double getprobab(arma::vec mu, arma::vec sigmae, arma::vec lower, arma::vec upper, arma::mat Z, List GH){
  arma::vec w = GH["weights"] ;
  arma::mat x = GH["nodes"] ;
  const arma::uword n_resp = mu.n_elem ;

  double total = 0. ;
  for(arma::uword q = 0 ; q < w.n_elem ; ++q){
    double integrand = 1. ;
    for(arma::uword j = 0 ; j < n_resp ; ++j){
      const arma::mat shift = Z.row(j) * x.row(q).t() ;   // 1 x 1
      const double center = mu(j) + shift(0,0) ;
      integrand *= arma::normcdf(upper(j), center, sigmae(j)) -
                   arma::normcdf(lower(j), center, sigmae(j)) ;
    }
    total += w(q) * integrand ;
  }
  return total ;
}

// Compute E( prod_j Y_j^lambda_j | lower < Y < upper ) by quadrature.
//
// For each node q and response j, the conditional mean is
// m_jq = mu_j + Z[j, ] . nodes[q, ], and P_jq is the N(m_jq, sigmae_j^2)
// probability of (lower_j, upper_j). Response j contributes P_jq times its
// truncated lambda_j-th moment (just P_jq when lambda_j == 0 or P_jq == 0).
// The weighted sum over nodes is divided by `probab`, presumably the
// interval probability from getprobab(). D is accepted but not used here.
// [[Rcpp::export]]
double getExpect(arma::uvec lambda, arma::vec lower, arma::vec upper, arma::vec mu, 
                 arma::vec sigmae, double probab, arma::mat Z, arma::mat D, 
                 arma::vec weights, arma::mat nodes){
  const arma::uword n_resp = mu.n_elem ;
  double acc = 0. ;

  for(arma::uword q = 0 ; q < weights.n_elem ; ++q){
    double integrand = 1. ;
    for(arma::uword j = 0 ; j < n_resp ; ++j){
      const arma::mat shift = Z.row(j) * nodes.row(q).t() ;   // 1 x 1
      const double center = mu(j) + shift(0,0) ;
      const double pj = arma::normcdf(upper(j), center, sigmae(j)) -
                        arma::normcdf(lower(j), center, sigmae(j)) ;
      // Skip the moment call when it is unnecessary (order 0) or would
      // fail (empty interval).
      const bool probability_only = (pj == 0.) || (lambda(j) == 0) ;
      integrand *= probability_only
                     ? pj
                     : pj * univariatetrunmnt(lambda(j), center, lower(j), upper(j), sigmae(j)) ;
    }
    acc += weights(q) * integrand ;
  }
  return acc / probab ;
}

// Compute the truncated mean and variance-covariance matrix of Y given
// lower < Y < upper, built from getExpect():
//   tmean_i    = E(Y_i | .)
//   tvar(i, j) = E(Y_i Y_j | .) - tmean_i * tmean_j
// Returns a List with `tmean` (vector) and `tvar` (symmetric matrix).
// [[Rcpp::export]]
List getMeanVar(arma::vec lower, arma::vec upper, arma::vec mu, 
                arma::vec sigmae, double probab, arma::mat Z, arma::mat D, 
                arma::vec weights, arma::mat nodes){
  const arma::uword dim = mu.n_elem ;
  arma::vec first(dim, arma::fill::zeros) ;          // first moments E(Y_i)
  arma::mat second(dim, dim, arma::fill::zeros) ;    // second moments E(Y_i Y_j)
  arma::uvec order(dim, arma::fill::zeros) ;         // exponent of each Y_j

  // E( prod_j Y_j^order(j) | lower < Y < upper ) for the current `order`.
  auto expect = [&]() {
    return getExpect(order, lower, upper, mu, sigmae, probab, Z, D, weights, nodes) ;
  } ;

  for(arma::uword i = 0 ; i < dim ; ++i){
    order(i) = 2 ;
    second(i,i) = expect() ;
    order(i) = 1 ;
    first(i) = expect() ;
    for(arma::uword j = i + 1 ; j < dim ; ++j){
      order(j) = 1 ;
      second(i,j) = second(j,i) = expect() ;
      order(j) = 0 ;
    }
    order(i) = 0 ;      // every other entry is already back to 0
  }

  // Cov(Y) = E(Y Y') - E(Y) E(Y)'
  second -= first * first.t() ;

  return(List::create(_["tmean"] = first,
                      _["tvar"] = second)) ;
}
