
#' @title Local Laplace VB
#' @description
#'  A variational Bayesian algorithm, based on the Laplace Spike-and-Slab prior, tailored for
#'    multi-source heterogeneous models. Variable selection is performed exclusively on the
#'    homogeneous covariates.
#'
#' @param X Homogeneous covariates: a 3-d array indexed as
#'    (observation, covariate, source).
#' @param Z Heterogeneous covariates: a 3-d array indexed as
#'    (observation, covariate, source).
#' @param Y Responses, with one column per source (column `k` is paired with
#'    `X[, , k]` and `Z[, , k]`).
#' @param max_iter Maximum number of iterations, Default: 1000
#' @param tol Algorithm convergence tolerance, applied to the largest absolute
#'    change of `entropy(gamma)` between two consecutive iterations, Default: 1e-6
#' @param a A prior of Beta distribution, Default: 1
#' @param b A prior of Beta distribution, Default: 10
#' @param lambda A prior of Laplace distribution, Default: 1
#'
#' @return A list with the elements
#'   \describe{
#'     \item{mu}{Mean of the homogeneity coefficients (p x 1 matrix).}
#'     \item{sigma}{Scale parameter of the homogeneity coefficients (p x 1 matrix).}
#'     \item{gamma}{Selection coefficients (p x 1 matrix).}
#'     \item{m}{Means of the heterogeneity coefficients (q x K matrix).}
#'     \item{s2}{Covariances of the heterogeneity coefficients (q x q x K array).}
#'   }
#'   `NA` is returned instead when the inner optimisation fails and every
#'   entry of `gamma` is still identical.
vb_lap_local <- function(X, Z, Y, max_iter = 1000, tol = 1e-6, a = 1, b = 10, lambda = 1) {
  stopifnot(
    length(dim(X)) == 3L,
    length(dim(Z)) == 3L,
    dim(X)[3] >= 1L
  )

  # Factory for the per-coordinate objective in (mu, sigma). The closure
  # returns both the value and the analytic gradient of
  #   coef_sq * (mu^2 + sigma^2) + coef_lin * mu - log|sigma|
  #     + lambda * (mu * erf(mu / (sqrt(2) * sigma))
  #                 + sigma * sqrt(2 / pi) * exp(-mu^2 / (2 * sigma^2))).
  fn <- function(coef_sq, coef_lin, lambda) {
    function(x) {
      mu <- x[1]
      sigma <- x[2]

      exp_term <- lambda * (1 / sqrt(2)) * (2 / sqrt(pi)) * exp(-0.5 * (mu / sigma)^2)
      erf_term <- lambda * pracma::erf(sqrt(0.5) * mu / sigma)

      gradient <- c(
        2 * coef_sq * mu + erf_term + coef_lin,
        2 * coef_sq * sigma + exp_term - 1 / sigma
      )

      obj_value <- (erf_term + coef_lin) * mu +
        coef_sq * (mu^2 + sigma^2) -
        log(abs(sigma)) + sigma * exp_term

      list(value = obj_value, gradient = gradient)
    }
  }

  # Extract source k as a genuine matrix. Plain `arr[, , k]` collapses to a
  # vector when there is a single covariate, which breaks rbind()/ncol() below.
  slab <- function(arr, k) {
    array(arr[, , k], dim = dim(arr)[1:2], dimnames = dimnames(arr)[1:2])
  }

  p <- dim(X)[2]
  K <- dim(X)[3]
  q <- dim(Z)[2]
  sources <- seq_len(K)

  # Average the per-source noise estimates and rescale all data by it.
  noisy_sd <- 0
  for (k in sources) {
    noisy_sd <- noisy_sd +
      selectiveInference::estimateSigma(cbind(slab(X, k), slab(Z, k)), Y[, k])$sigmahat
  }
  noisy_sd <- noisy_sd / K
  X <- X / noisy_sd
  Y <- Y / noisy_sd
  Z <- Z / noisy_sd

  # Initial variational parameters.
  mu <- matrix(0, nrow = p, ncol = 1)
  sigma <- matrix(1, nrow = p, ncol = 1)
  gamma <- matrix(0.5, nrow = p, ncol = 1)
  s2 <- array(data = diag(q), dim = c(q, q, K))
  m <- matrix(1, nrow = q, ncol = K)

  old_entr <- entropy(gamma)

  # Per-source design matrices, and the row-stacked versions over all sources.
  X_list <- lapply(sources, function(k) slab(X, k))
  Z_list <- lapply(sources, function(k) slab(Z, k))
  all_X <- do.call(rbind, X_list)
  all_Y <- do.call(c, lapply(sources, function(k) Y[, k]))

  # solve(Z'Z + I) depends only on Z, so compute it once instead of on
  # every sweep.
  post_cov <- lapply(Z_list, function(Zk) solve(crossprod(Zk) + diag(q)))

  half_diag <- 0.5 * gram_diag(all_X)
  approx_mean <- gamma * mu
  # Running value of all_X %*% (gamma * mu), updated one column at a time.
  X_appm <- all_X %*% approx_mean

  exit_loop <- FALSE

  # Constant part of the log-odds passed to sigmoid() when updating gamma.
  const_lodds <- (log(a) - log(b)) + 0.5
  const_lodds <- const_lodds + 0.5 * log(pi) + log(lambda) - 0.5 * log(2)

  for (i in seq_len(max_iter)) {
    # Refresh the heterogeneous fit with the current m before sweeping.
    all_Zm <- do.call(c, lapply(sources, function(k) Z_list[[k]] %*% m[, k]))
    YX_vec <- t(all_Y - all_Zm) %*% all_X

    for (j in seq_along(mu)) {
      x_j <- all_X[, j]

      # Remove coordinate j's contribution before optimising it.
      X_appm <- X_appm - approx_mean[j] * x_j

      obj_fn <- fn(half_diag[j], as.numeric(x_j %*% X_appm - YX_vec[j]), lambda)

      optim_result <- tryCatch(
        optim(
          par = c(mu[j], sigma[j]),
          fn = function(par) obj_fn(par)$value,
          gr = function(par) obj_fn(par)$gradient,
          method = "L-BFGS-B"
        ),
        error = function(e) NULL
      )

      if (is.null(optim_result)) {
        # Coordinate j keeps its old values, so put its contribution back;
        # otherwise X_appm would disagree with mu/gamma from here on.
        X_appm <- X_appm + approx_mean[j] * x_j
        exit_loop <- TRUE
        break
      }

      mu[j] <- optim_result$par[1]
      sigma[j] <- optim_result$par[2]
      gamma[j] <- sigmoid(const_lodds - optim_result$value)

      approx_mean[j] <- gamma[j] * mu[j]
      X_appm <- X_appm + approx_mean[j] * x_j
    }

    # Update the heterogeneous coefficients of every source.
    # approx_mean always equals gamma * mu elementwise at this point.
    for (k in sources) {
      s2[, , k] <- post_cov[[k]]
      resid_k <- Y[, k] - X_list[[k]] %*% approx_mean
      m[, k] <- post_cov[[k]] %*% crossprod(Z_list[[k]], resid_k)
    }

    # Give up if the optimiser has failed and gamma is still constant.
    if (exit_loop && all(gamma == gamma[1])) {
      return(NA)
    }

    new_entr <- entropy(gamma)
    if (max(abs(new_entr - old_entr)) <= tol) {
      break
    }
    old_entr <- new_entr
  }

  list(mu = mu, gamma = gamma, m = m, sigma = sigma, s2 = s2)
}
