# Fit each column of Y as a simplex-constrained mixture of the columns of x,
# measuring the fit in centred log-ratio (clr) space.
#
# For column j the weights b (b >= 0, sum(b) == 1) are chosen to minimise
#   sum((clr(Y[, j]) - clr(x %*% b))^2),
# where clr() is taken over the n observations (rows). Minimisation uses
# multiplicative (exponentiated-gradient) updates followed by renormalisation,
# which keeps b on the simplex at every step.
#
# Args:
#   Y:     n x d matrix of strictly positive values; each column is fitted
#          separately.
#   x:     n x p matrix of predictors; x %*% b is expected to stay strictly
#          positive (log() is applied to it).
#   tol:   convergence threshold on the maximum absolute change in b between
#          consecutive iterations.
#   maxit: maximum number of iterations per column of Y.
#   alpha: step size of the exponentiated-gradient update.
#
# Returns a list with
#   coefficients: p x d matrix of weights rounded to 12 decimals, one column
#                 per column of Y;
#   value:        length-d vector of residual sums of squares in clr space;
#   iterations:   length-d vector with the iterations used for each column.
mait <- function(Y, x, tol = 1e-8, maxit = 50000, alpha = 0.01) {
  if (anyNA(Y) || any(Y <= 0)) {
    stop("Y must contain only strictly positive, non-missing values",
         call. = FALSE)
  }
  clr <- function(z) {
    lz <- log(z)
    lz - mean(lz)
  }

  p <- ncol(x)
  d <- ncol(Y)
  # Centre log(Y) column by column so each target is centred over the same
  # observations as clr(x %*% b). The residuals then sum to zero, which is
  # what makes the simplified gradient below exact (and keeps the reported
  # objective free of a spurious constant offset).
  lY <- log(Y)
  clY <- sweep(lY, 2, colMeans(lY))

  tx <- t(x)
  B <- matrix(0, p, d)
  obj <- iterations <- numeric(d)

  for (j in seq_len(d)) {
    cly <- clY[, j]
    be <- rep(1 / p, p)
    n_iter <- 0L
    for (iter in seq_len(maxit)) {
      n_iter <- iter
      y_hat <- drop(x %*% be)
      res <- cly - clr(y_hat)
      grad <- drop(-2 * tx %*% (res / y_hat))
      # Exponentiated-gradient step. Subtracting mean(grad) only rescales
      # be_new by a constant, which the renormalisation removes, but it keeps
      # the exponent centred around zero.
      be_new <- be * exp(-alpha * (grad - mean(grad)))
      be_new <- be_new / sum(be_new)
      converged <- max(abs(be_new - be)) < tol
      # Keep the accepted step even when it triggers convergence.
      be <- be_new
      if (converged) {
        break
      }
    }
    obj[j] <- sum((cly - clr(drop(x %*% be)))^2)
    B[, j] <- be
    iterations[j] <- n_iter
  }

  list(coefficients = round(B, 12), value = obj, iterations = iterations)
}

