# em.R
# EM algorithms for alpha-stable mixture estimation

# =============================================================================
# SIMPLE EM INITIALIZATION
# =============================================================================


#' Two-component mixture initialization via k-means and ECF estimation
#'
#' Splits the observations into two groups with k-means and runs
#' \code{ecf_estimate_all()} on each group. No EM iterations are performed
#' in this function; \code{max_iter} is accepted but not used by the body.
#'
#' @param X Numeric vector of observations.
#' @param max_iter Integer. Maximum number of iterations (currently unused).
#'
#' @return List with \code{lambda1} (share of observations assigned to the
#'   first k-means group) and \code{params1}, \code{params2} (whatever
#'   \code{ecf_estimate_all()} returns for the first and second group).
#' @export
simple_em_real <- function(X, max_iter = 10) {
  # k-means labels are 1/2; shift them to 0/1.
  group_id <- kmeans(X, centers = 2)$cluster - 1
  share_first <- mean(group_id == 0)

  # One ECF-based fit per group, in label order (0 first, then 1).
  # ecf_estimate_all() is not defined in this block.
  fits <- lapply(0:1, function(g) ecf_estimate_all(X[group_id == g]))

  list(
    lambda1 = share_first,
    params1 = fits[[1]],
    params2 = fits[[2]]
  )
}

# =============================================================================
# EM ALGORITHM FOR ALPHA-STABLE MIXTURE
# =============================================================================



#' EM algorithm for alpha-stable mixture
#'
#' Estimates parameters of an alpha-stable mixture using EM with optional
#' random initialization. Component densities are evaluated with
#' \code{dstable(..., pm = 1)}. In the M-step each component is refitted by
#' calling \code{stable_fit_init()} on a pseudo-sample in which every
#' observation is replicated in proportion to its responsibility.
#'
#' @param data Numeric vector of observations.
#' @param n_components Integer. Number of mixture components (at least 1).
#' @param max_iter Integer. Maximum number of EM iterations.
#' @param tol Numeric. Convergence tolerance on the absolute change of the
#'   log-likelihood between consecutive iterations.
#' @param random_init Logical. Whether to use random initialization.
#' @param debug Logical. Whether to print debug information.
#'
#' @return List with estimated \code{weights} and per-component parameter
#'   vectors \code{alphas}, \code{betas}, \code{gammas}, \code{deltas}.
#' @export
em_alpha_stable <- function(data, n_components = 2, max_iter = 100, tol = 1e-4,
                            random_init = TRUE, debug = TRUE) {
  N <- length(data)
  if (N < 1) {
    stop("data must contain at least one observation.", call. = FALSE)
  }
  if (length(n_components) != 1 || !is.finite(n_components) ||
      n_components < 1) {
    stop("n_components must be a single integer >= 1.", call. = FALSE)
  }
  comp_idx <- seq_len(n_components)

  # Initialize parameters
  weights <- rep(1 / n_components, n_components)
  if (random_init) {
    alphas <- runif(n_components, 1.2, 1.8)
    betas <- runif(n_components, -0.5, 0.5)
    gammas <- runif(n_components, 0.5, 2.0)
    deltas <- runif(n_components, min(data), max(data))
  } else {
    alphas <- rep(1.8, n_components)
    betas <- rep(0, n_components)
    gammas <- rep(sd(data) / 2, n_components)
    deltas <- seq(min(data), max(data), length.out = n_components)
  }

  # N x n_components matrix of weight[k] * density_k(data); the E-step floors
  # the densities at 1e-300, the log-likelihood evaluation does not.
  mixture_terms <- function(weights, alphas, betas, gammas, deltas,
                            floor_pdf) {
    terms <- matrix(0, N, n_components)
    for (k in comp_idx) {
      pdf_vals <- dstable(data, alpha = alphas[k], beta = betas[k],
                          gamma = gammas[k], delta = deltas[k], pm = 1)
      if (floor_pdf) {
        pdf_vals <- pmax(pdf_vals, 1e-300)
      }
      terms[, k] <- weights[k] * pdf_vals
    }
    terms
  }

  log_likelihood_old <- -Inf

  for (iteration in seq_len(max_iter)) {
    # E-step: responsibilities, normalised per observation
    responsibilities <- mixture_terms(weights, alphas, betas, gammas, deltas,
                                      floor_pdf = TRUE)
    responsibilities <- responsibilities / (rowSums(responsibilities) + 1e-12)

    # M-step: update parameters
    for (k in comp_idx) {
      r <- responsibilities[, k]
      Nk <- sum(r)

      # Skip empty components and components whose density evaluation
      # produced NaN/Inf; without the is.finite() test `if (NA)` would abort.
      if (!is.finite(Nk) || Nk < 1e-8) next

      weights[k] <- Nk / N

      # Weighted pseudo-sample: observation i appears round(r_i / Nk * N)
      # times.
      expanded_data <- rep(data, round(r / Nk * N))

      if (length(expanded_data) > 10 && sd(expanded_data) > 1e-8) {
        tryCatch({
          params <- stable_fit_init(expanded_data)
          alphas[k] <- params$alpha
          betas[k] <- params$beta
          gammas[k] <- params$gamma
          deltas[k] <- params$delta
        }, error = function(e) {
          # Best effort: keep the previous parameters of this component.
          if (debug) {
            message("Fit failed for component ", k, ": ",
                    conditionMessage(e), "\n")
          }
        })
      }
    }

    # Skipped components keep their old weight, so restore sum(weights) == 1.
    weights <- weights / sum(weights)

    # Log-likelihood under the updated parameters
    likelihood <- mixture_terms(weights, alphas, betas, gammas, deltas,
                                floor_pdf = FALSE)
    total_likelihood <- sum(log(rowSums(likelihood) + 1e-12))

    if (debug) {
      message(sprintf("[Iteration %d] Log-Likelihood: %.6f\n", iteration, total_likelihood))
    }

    # A NaN/Inf log-likelihood cannot be compared; stop instead of crashing.
    if (!is.finite(total_likelihood)) {
      warning("Non-finite log-likelihood detected. Stopping iteration.",
              call. = FALSE)
      break
    }

    # Check convergence
    if (abs(total_likelihood - log_likelihood_old) < tol) {
      if (debug) message(sprintf("Converged after %d iterations.\n", iteration))
      break
    }

    log_likelihood_old <- total_likelihood
  }

  list(
    weights = weights,
    alphas = alphas,
    betas = betas,
    gammas = gammas,
    deltas = deltas
  )
}

# =============================================================================
# EM WITH CUSTOM ESTIMATOR FUNCTION
# =============================================================================


#' EM algorithm for alpha-stable mixture using a custom estimator
#'
#' Two-component stochastic EM: in every iteration each observation is
#' assigned to a component by sampling from its posterior membership
#' probabilities, and the component parameters are re-estimated from the
#' sampled groups with \code{estimator_func}. Initial groups come from
#' k-means.
#'
#' @param data Numeric vector of observations.
#' @param u Numeric vector of frequency values for ECF; passed unchanged as
#'   second argument to \code{estimator_func}.
#' @param estimator_func Function \code{(x, u)} returning a list with elements
#'   \code{alpha}, \code{beta}, \code{delta} and \code{gamma}.
#' @param max_iter Integer. Maximum number of EM iterations.
#' @param epsilon Numeric. Convergence threshold on the relative change of the
#'   log-likelihood.
#'
#' @return List with \code{weights} (weight of the first component),
#'   \code{params1}, \code{params2} (numeric vectors built from
#'   \code{alpha}, \code{beta}, \code{delta}, \code{gamma} in that order, the
#'   last two passed through \code{ensure_positive_scale()}), and
#'   \code{log_likelihood} (last finite log-likelihood evaluated, or
#'   \code{-Inf} if none was).
#' @export
em_stable_mixture <- function(data, u, estimator_func, max_iter = 300, epsilon = 1e-3) {
  S <- data
  n <- length(S)

  # NOTE(review): vectors are stored as (alpha, beta, delta, gamma) but read
  # back below as gamma = p[3], delta = p[4], and ensure_positive_scale() is
  # applied to both entries. Whether this matches the naming convention of
  # estimator_func cannot be verified from this file -- confirm with callers.
  as_param_vector <- function(est) {
    c(est$alpha, est$beta,
      ensure_positive_scale(est$delta), ensure_positive_scale(est$gamma))
  }

  component_density <- function(x, p) {
    dstable(x, alpha = p[1], beta = p[2], gamma = p[3], delta = p[4], pm = 1)
  }

  # Re-estimate one component; keeps `current` when the group has fewer than
  # two points, the estimator fails, or it returns non-finite values
  # (deliberate best effort, as in the original code).
  refit <- function(x, current) {
    if (length(x) < 2) {
      return(current)
    }
    tryCatch({
      est <- estimator_func(x, u)
      if (all(is.finite(c(est$alpha, est$beta, est$delta, est$gamma)))) {
        as_param_vector(est)
      } else {
        current
      }
    }, error = function(e) current)
  }

  # Initial clustering
  kmeans_result <- kmeans(matrix(S, ncol = 1), centers = 2)
  labels <- kmeans_result$cluster - 1  # Convert to 0-based

  # Initial parameter estimation
  p1 <- as_param_vector(estimator_func(S[labels == 0], u))
  p2 <- as_param_vector(estimator_func(S[labels == 1], u))
  w <- mean(labels == 0)

  LV <- -Inf

  for (s in seq_len(max_iter)) {
    cc <- integer(n)

    # E-step: sample a label for every observation from its posterior
    for (i in seq_len(n)) {
      # The handler returns the fallback value instead of assigning with
      # `<<-`, which could create a global `v`.
      v <- tryCatch({
        v1 <- log(w) + log(component_density(S[i], p1) + 1e-10)
        v2 <- log(1 - w) + log(component_density(S[i], p2) + 1e-10)
        probs <- exp(c(v1, v2) - max(v1, v2))
        probs <- probs / sum(probs)
        pmax(pmin(probs, 1), 0)
      }, error = function(e) c(0.5, 0.5))

      # sample() rejects NA/NaN probabilities; fall back to a fair draw.
      if (length(v) != 2 || !all(is.finite(v)) || sum(v) <= 0) {
        v <- c(0.5, 0.5)
      }
      cc[i] <- sample(c(0, 1), 1, prob = v)
    }

    # Keep the weight away from 0 and 1 so log(w) and log(1 - w) stay finite.
    w <- max(0.01, min(0.99, mean(cc == 0)))

    # M-step: re-estimate parameters
    p1 <- refit(S[cc == 0], p1)
    p2 <- refit(S[cc == 1], p2)

    # Log-likelihood
    pdf1 <- component_density(S, p1)
    pdf2 <- component_density(S, p2)
    LVn <- sum(log(w * pdf1 + (1 - w) * pdf2))

    if (!is.finite(LVn)) {
      warning("Non-finite log-likelihood detected. Stopping iteration.")
      break
    }

    # LV is -Inf until the first iteration has completed; only compare once a
    # previous finite value exists. (Testing is.finite(LV) as a stop condition
    # would end the loop in iteration 1 every time.)
    converged <- is.finite(LV) && abs(LVn) > 1e-10 &&
      abs(LVn - LV) / abs(LVn) < epsilon

    # Record the current value before leaving so the returned log-likelihood
    # belongs to the returned parameters.
    LV <- LVn
    if (converged) {
      break
    }

    message(sprintf("Iteration %d, Log-likelihood: %.6f\n", s, LVn))
  }

  list(
    weights = w,
    params1 = p1,
    params2 = p2,
    log_likelihood = LV
  )
}

# =============================================================================
# EM FOR TWO-COMPONENT ALPHA-STABLE MIXTURE (MLE-BASED)
# =============================================================================


#' EM algorithm for two-component alpha-stable mixture using MLE
#'
#' Classification-style EM for a two-component mixture: responsibilities are
#' computed from \code{dstable(..., pm = 1)} densities, every observation is
#' then hard-assigned to the component with the larger responsibility, and
#' each component is refitted with \code{fit_alpha_stable_mle()} on its
#' assigned observations. Initial groups come from k-means.
#'
#' The log-likelihood reported for an iteration is evaluated with the weight
#' and parameters used in that iteration's E-step, so weight and densities
#' always belong to the same parameter set.
#'
#' @param data Numeric vector of observations.
#' @param max_iter Integer. Maximum number of EM iterations.
#' @param tol Numeric. Convergence tolerance on the relative change of the
#'   log-likelihood.
#' @param return_trace Logical. Whether to return trace of responsibilities and log-likelihoods.
#'
#' @return List with \code{params1}, \code{params2} (as returned by
#'   \code{fit_alpha_stable_mle()}; elements 1-4 are used as alpha, beta,
#'   gamma, delta), \code{w} (weight of the first component) and, when
#'   \code{return_trace = TRUE}, \code{responsibilities_trace} and
#'   \code{log_likelihoods} with one entry per completed iteration.
#' @export
em_fit_alpha_stable_mixture <- function(data, max_iter = 200, tol = 1e-4, return_trace = FALSE) {
  if (length(data) < 2) {
    stop("Input data must contain at least two points.")
  }

  # Density floored at 1e-300 so row sums of the responsibilities stay > 0.
  floored_density <- function(p) {
    pmax(dstable(data, alpha = p[1], beta = p[2],
                 gamma = p[3], delta = p[4], pm = 1), 1e-300)
  }

  # Initial clustering
  kmeans_result <- kmeans(matrix(data, ncol = 1), centers = 2)
  labels <- kmeans_result$cluster - 1  # Convert to 0-based

  # Initial parameter estimation
  params1 <- fit_alpha_stable_mle(data[labels == 0])
  params2 <- fit_alpha_stable_mle(data[labels == 1])
  w <- mean(labels == 0)

  n_slots <- max(0, max_iter)
  responsibilities_trace <- vector("list", n_slots)
  log_likelihoods <- numeric(n_slots)
  n_done <- 0L
  prev_log_likelihood <- NULL

  for (iteration in seq_len(max_iter)) {
    # E-step
    joint <- cbind(w * floored_density(params1),
                   (1 - w) * floored_density(params2))
    mixture_pdf <- rowSums(joint)

    # Log-likelihood of the parameters that produced these densities.
    new_log_likelihood <- sum(log(mixture_pdf))
    if (!is.finite(new_log_likelihood)) {
      # NaN/Inf densities would otherwise break the hard assignment below.
      warning("Non-finite log-likelihood detected. Stopping iteration.",
              call. = FALSE)
      break
    }

    responsibilities <- joint / mixture_pdf

    n_done <- iteration
    if (return_trace) {
      responsibilities_trace[[iteration]] <- responsibilities
      log_likelihoods[iteration] <- new_log_likelihood
    }

    # M-step: hard assignment (ties go to the first component, like
    # which.max), then refit each component that kept at least two points.
    labels <- as.integer(responsibilities[, 2] > responsibilities[, 1])
    w <- mean(labels == 0)

    if (sum(labels == 0) >= 2) {
      params1 <- fit_alpha_stable_mle(data[labels == 0])
    }
    if (sum(labels == 1) >= 2) {
      params2 <- fit_alpha_stable_mle(data[labels == 1])
    }

    message(sprintf("Iteration %d: Log-Likelihood = %.6f\n", iteration, new_log_likelihood))

    if (!is.null(prev_log_likelihood)) {
      if (abs(new_log_likelihood - prev_log_likelihood) / (abs(new_log_likelihood) + 1e-12) < tol) {
        message("Converged.\n")
        break
      }
    }

    prev_log_likelihood <- new_log_likelihood
  }

  result <- list(
    params1 = params1,
    params2 = params2,
    w = w
  )
  if (return_trace) {
    done <- seq_len(n_done)
    result$responsibilities_trace <- responsibilities_trace[done]
    result$log_likelihoods <- log_likelihoods[done]
  }
  result
}

# =============================================================================
# EM FOR GAUSSIAN MIXTURE (2 COMPONENTS)
# =============================================================================


#' EM algorithm for two-component Gaussian mixture
#'
#' Estimates parameters of a Gaussian mixture using the EM algorithm. The
#' means start at \code{min(data)} and \code{max(data)}, both standard
#' deviations at 1 and the mixing proportion at 0.5. Responsibilities are
#' computed from log-densities, so observations far from both means do not
#' underflow to 0/0.
#'
#' @param data Numeric vector of observations.
#' @param max_iter Integer. Maximum number of EM iterations.
#' @param tol Numeric. Convergence tolerance on the absolute change of both
#'   means between consecutive iterations.
#'
#' @return A list containing estimated mixture parameters: \code{pi}, \code{mu1}, \code{mu2}, \code{sigma1}, \code{sigma2}.
#'   If an update produces non-finite values, a warning is raised and the
#'   last finite estimates are returned.
#' @importFrom stats dnorm
#' @export
em_estimation_mixture <- function(data, max_iter = 100, tol = 1e-6) {
  if (length(data) < 2) {
    stop("Data must contain at least two points.")
  }

  n_obs <- length(data)
  mix_prop <- 0.5  # local name avoids shadowing the base constant `pi`
  mu1 <- min(data)
  mu2 <- max(data)
  sigma1 <- 1.0
  sigma2 <- 1.0

  # Lower bound for the standard deviations; dnorm() with sd = 0 yields
  # Inf/-Inf log-densities.
  sigma_floor <- sqrt(.Machine$double.eps)

  for (i in seq_len(max_iter)) {
    # E-step in log space: w1 = r1 / (r1 + r2) = plogis(log r1 - log r2)
    log_r1 <- log(mix_prop) + dnorm(data, mu1, sigma1, log = TRUE)
    log_r2 <- log1p(-mix_prop) + dnorm(data, mu2, sigma2, log = TRUE)
    w1 <- stats::plogis(log_r1 - log_r2)
    w2 <- stats::plogis(log_r2 - log_r1)

    # M-step
    n1 <- sum(w1)
    n2 <- sum(w2)
    pi_new <- n1 / n_obs
    mu1_new <- sum(w1 * data) / n1
    mu2_new <- sum(w2 * data) / n2
    sigma1_new <- max(sqrt(sum(w1 * (data - mu1_new)^2) / n1), sigma_floor)
    sigma2_new <- max(sqrt(sum(w2 * (data - mu2_new)^2) / n2), sigma_floor)

    # An emptied component (n1 or n2 == 0) or non-finite input gives NaN;
    # stop here rather than failing in the comparison below.
    updated <- c(pi_new, mu1_new, mu2_new, sigma1_new, sigma2_new)
    if (!all(is.finite(updated))) {
      warning("Non-finite parameter update; returning last finite estimates.",
              call. = FALSE)
      break
    }

    converged <- abs(mu1 - mu1_new) < tol && abs(mu2 - mu2_new) < tol

    # Store the update before checking convergence so the most recent
    # estimates are the ones returned.
    mix_prop <- pi_new
    mu1 <- mu1_new
    mu2 <- mu2_new
    sigma1 <- sigma1_new
    sigma2 <- sigma2_new

    if (converged) {
      break
    }
  }

  list(
    pi = mix_prop,
    mu1 = mu1,
    sigma1 = sigma1,
    mu2 = mu2,
    sigma2 = sigma2
  )
}



