/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.hmc;

import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.GradientProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.operators.hmc.NumericalHessianFromGradient;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;
import java.util.logging.Logger;

public interface GradientWrtParameterProvider {
    /**
     * Tolerance constant with value 0.1.
     * NOTE(review): not referenced in the visible part of this file; presumably a
     * default tolerance for callers of the gradient check -- confirm against callers.
     */
    public static final Double TOLERANCE = 0.1;

    /**
     * @return the likelihood associated with this provider; the numerical gradient
     *         check in this file evaluates its log-likelihood after perturbing the
     *         parameter returned by {@link #getParameter()}
     */
    public Likelihood getLikelihood();

    /**
     * @return the parameter this provider is associated with; the numerical gradient
     *         check in this file perturbs this parameter's values
     */
    public Parameter getParameter();

    /**
     * @return the dimension reported by this provider
     *         (NOTE(review): presumably the length of the gradient array -- confirm)
     */
    public int getDimension();

    /**
     * @return the gradient of the log density as an array; the gradient check in this
     *         file compares it element by element against a numerical gradient
     */
    public double[] getGradientLogDensity();

    /**
     * Builds a textual report listing an analytic and a numeric gradient and,
     * optionally, verifies that the two agree.
     *
     * <p>When checking is enabled, element {@code i} is treated as a mismatch if either
     * <ul>
     *   <li>the relative difference {@code 2 * (a - n) / (a + n)} exceeds {@code d} in
     *       absolute value while both {@code |a|} and {@code |n|} exceed {@code d2}, or</li>
     *   <li>one of the two values is exactly zero and {@code |a + n|} exceeds {@code d}.</li>
     * </ul>
     * On the first mismatch the report (including a "Difference @" line that uses a
     * 1-based element index) is logged at INFO level to the {@code dr.inference.hmc}
     * logger and a {@link MismatchException} without message is thrown.
     *
     * @param string  header text placed at the start of the report
     * @param dArray  analytic gradient
     * @param dArray2 numeric gradient
     * @param bl      whether to compare the two gradients
     * @param d       tolerance used by both mismatch criteria
     * @param d2      magnitude both values must exceed for the relative criterion to apply
     * @return the report text when no mismatch is detected (or checking is disabled)
     * @throws MismatchException on the first element that fails the comparison
     */
    public static String makeReport(String string, double[] dArray, double[] dArray2, boolean bl, double d, double d2) throws MismatchException {
        final StringBuilder report = new StringBuilder(string);
        report.append("analytic: ").append(new Vector(dArray)).append("\n");
        report.append("numeric : ").append(new Vector(dArray2)).append("\n");

        // Nothing to verify: hand back the plain listing.
        if (!bl) {
            return report.toString();
        }

        for (int index = 0; index < dArray.length; ++index) {
            final double analytic = dArray[index];
            final double numeric = dArray2[index];
            final double relativeDifference = 2.0 * (analytic - numeric) / (analytic + numeric);

            final boolean bothAboveThreshold = Math.abs(analytic) > d2 && Math.abs(numeric) > d2;
            final boolean relativeMismatch = Math.abs(relativeDifference) > d && bothAboveThreshold;
            final boolean zeroMismatch = (analytic == 0.0 || numeric == 0.0) && Math.abs(analytic + numeric) > d;

            if (relativeMismatch || zeroMismatch) {
                report.append("\nDifference @ ").append(index + 1).append(": ").append(analytic).append(" ").append(numeric).append(" ").append(relativeDifference).append("\n");
                Logger.getLogger("dr.inference.hmc").info(report.toString());
                throw new MismatchException();
            }
        }
        return report.toString();
    }

    /**
     * Convenience overload of the five-argument
     * {@code getReportAndCheckForError} that supplies {@code null} for the final
     * (small-value threshold) argument.
     *
     * @param gradientWrtParameterProvider provider whose gradient is checked
     * @param d  lower bound handed to the numerical check
     * @param d2 upper bound handed to the numerical check
     * @param d3 tolerance, or {@code null}
     * @return the report produced by the five-argument overload
     */
    public static String getReportAndCheckForError(GradientWrtParameterProvider gradientWrtParameterProvider, double d, double d2, Double d3) {
        final Double noSmallThreshold = null;
        return getReportAndCheckForError(gradientWrtParameterProvider, d, d2, d3, noSmallThreshold);
    }

    /**
     * Runs a numerical gradient check for the given provider and returns its report,
     * converting a detected mismatch into an unchecked exception.
     *
     * <p>The message of the thrown {@link RuntimeException} is, in order of preference:
     * the mismatch exception's own message, the name of the provider's parameter, or
     * the fixed text {@code "Gradient check failure"}. The original
     * {@link MismatchException} is attached as the cause so its stack trace is kept.
     *
     * @param gradientWrtParameterProvider provider whose gradient is checked
     * @param d  lower bound handed to the numerical check
     * @param d2 upper bound handed to the numerical check
     * @param d3 tolerance, or {@code null}
     * @param d4 small-value threshold, or {@code null}
     * @return the report text when no mismatch is detected
     * @throws RuntimeException if the check reports a mismatch
     */
    public static String getReportAndCheckForError(GradientWrtParameterProvider gradientWrtParameterProvider, double d, double d2, Double d3, Double d4) {
        try {
            return new CheckGradientNumerically(gradientWrtParameterProvider, d, d2, d3, d4).getReport();
        }
        catch (MismatchException mismatchException) {
            String message = mismatchException.getMessage();
            if (message == null) {
                message = gradientWrtParameterProvider.getParameter().getParameterName();
            }
            if (message == null) {
                message = "Gradient check failure";
            }
            // Chain the cause so the location of the failing comparison is not lost.
            throw new RuntimeException(message, mismatchException);
        }
    }

    /**
     * Checked exception thrown by {@code makeReport} when an analytic and a numeric
     * value fail the comparison. It declares no constructors of its own, so instances
     * carry neither a message nor a cause.
     */
    public static class MismatchException extends Exception {
    }

    /**
     * Compares the gradient reported by a {@link GradientWrtParameterProvider} with a
     * numerical gradient of the provider's log-likelihood, taken with respect to the
     * provider's parameter.
     */
    public static class CheckGradientNumerically {
        private final GradientWrtParameterProvider provider;
        private final Parameter parameter;
        private final double lowerBound;
        private final double upperBound;
        private final boolean checkValues;
        private final double tolerance;
        private final double smallThreshold;

        /**
         * Log-likelihood viewed as a function of the parameter values. Every
         * evaluation writes the argument into the parameter first. The same lower and
         * upper bound is reported for every argument index.
         */
        private final MultivariateFunction numeric = new MultivariateFunction(){

            @Override
            public double evaluate(double[] dArray) {
                // Qualified 'this': setParameter is a private member of the enclosing
                // class, not of this anonymous MultivariateFunction.
                CheckGradientNumerically.this.setParameter(dArray);
                return provider.getLikelihood().getLogLikelihood();
            }

            @Override
            public int getNumArguments() {
                return parameter.getDimension();
            }

            @Override
            public double getLowerBound(int n) {
                return lowerBound;
            }

            @Override
            public double getUpperBound(int n) {
                return upperBound;
            }
        };

        /**
         * @param gradientWrtParameterProvider provider under test
         * @param d  lower bound reported by the numerical function for every argument
         * @param d2 upper bound reported by the numerical function for every argument
         * @param d3 tolerance; {@code null} disables the value comparison
         * @param d4 small-value threshold; {@code null} is treated as 0.0
         */
        CheckGradientNumerically(GradientWrtParameterProvider gradientWrtParameterProvider, double d, double d2, Double d3, Double d4) {
            this.provider = gradientWrtParameterProvider;
            this.parameter = gradientWrtParameterProvider.getParameter();
            this.lowerBound = d;
            this.upperBound = d2;
            this.checkValues = d3 != null;
            this.tolerance = this.checkValues ? d3 : 0.0;
            this.smallThreshold = d4 != null ? d4 : 0.0;
        }

        /**
         * Writes all values quietly and then fires a single parameter-changed event.
         */
        private void setParameter(double[] dArray) {
            for (int i = 0; i < dArray.length; ++i) {
                this.parameter.setParameterValueQuietly(i, dArray[i]);
            }
            this.parameter.fireParameterChangedEvent();
        }

        /**
         * Computes the numerical gradient of the log-likelihood at the parameter's
         * current values. The values are restored afterwards, including when the
         * evaluation throws.
         *
         * @return the numerical gradient
         */
        public double[] getNumericalGradient() {
            double[] savedValues = this.parameter.getParameterValues();
            try {
                // A second values array is passed to the differentiator so that
                // savedValues stays pristine (assumes getParameterValues() returns a
                // copy and that gradient() may modify its argument -- TODO confirm).
                return NumericalDerivative.gradient(this.numeric, this.parameter.getParameterValues());
            }
            finally {
                this.setParameter(savedValues);
            }
        }

        /**
         * Obtains the provider's gradient first, then the numerical gradient, and
         * formats both through {@code makeReport} under the header {@code "Gradient\n"}.
         *
         * @return the report text
         * @throws MismatchException if value checking is enabled and a mismatch is found
         */
        public String getReport() throws MismatchException {
            double[] analytic = this.provider.getGradientLogDensity();
            double[] numerical = this.getNumericalGradient();
            return GradientWrtParameterProvider.makeReport("Gradient\n", analytic, numerical, this.checkValues, this.tolerance, this.smallThreshold);
        }
    }

    /**
     * Adapts a {@link GradientProvider} together with a parameter and a likelihood to
     * the {@link GradientWrtParameterProvider} interface. The diagonal Hessian is
     * obtained through {@link NumericalHessianFromGradient}; the full Hessian is not
     * available.
     */
    public static class ParameterWrapper
            implements GradientWrtParameterProvider, HessianWrtParameterProvider, Reportable {

        final GradientProvider provider;
        final Parameter parameter;
        final Likelihood likelihood;

        /**
         * @param gradientProvider source of the gradient, queried at the parameter's values
         * @param parameter        parameter returned by {@link #getParameter()}
         * @param likelihood       likelihood returned by {@link #getLikelihood()}
         */
        public ParameterWrapper(GradientProvider gradientProvider, Parameter parameter, Likelihood likelihood) {
            this.provider = gradientProvider;
            this.parameter = parameter;
            this.likelihood = likelihood;
        }

        @Override
        public Likelihood getLikelihood() {
            return likelihood;
        }

        @Override
        public Parameter getParameter() {
            return parameter;
        }

        /** @return the dimension of the wrapped parameter */
        @Override
        public int getDimension() {
            return parameter.getDimension();
        }

        /** @return the wrapped provider's gradient evaluated at the parameter's current values */
        @Override
        public double[] getGradientLogDensity() {
            final double[] currentValues = parameter.getParameterValues();
            return provider.getGradientLogDensity(currentValues);
        }

        /** @return the diagonal Hessian computed by a new {@link NumericalHessianFromGradient} over this wrapper */
        @Override
        public double[] getDiagonalHessianLogDensity() {
            return new NumericalHessianFromGradient(this).getDiagonalHessianLogDensity();
        }

        /** Always throws; the full Hessian is not implemented. */
        @Override
        public double[][] getHessianLogDensity() {
            throw new RuntimeException("Not yet implemented");
        }

        /**
         * Produces the gradient report using the parameter's bounds at index 0 as the
         * lower and upper limits and a {@code null} tolerance.
         */
        @Override
        public String getReport() {
            final double lowerLimit = parameter.getBounds().getLowerLimit(0);
            final double upperLimit = parameter.getBounds().getUpperLimit(0);
            return GradientWrtParameterProvider.getReportAndCheckForError(this, lowerLimit, upperLimit, null);
        }
    }

    /**
     * Decorator that reports the negated gradient of another
     * {@link GradientWrtParameterProvider}. Likelihood, parameter and dimension are
     * forwarded unchanged.
     */
    public static class Negative
    implements GradientWrtParameterProvider {
        private final GradientWrtParameterProvider provider;

        /**
         * @param gradientWrtParameterProvider provider whose gradient is negated
         */
        public Negative(GradientWrtParameterProvider gradientWrtParameterProvider) {
            this.provider = gradientWrtParameterProvider;
        }

        @Override
        public Likelihood getLikelihood() {
            return this.provider.getLikelihood();
        }

        @Override
        public Parameter getParameter() {
            return this.provider.getParameter();
        }

        @Override
        public int getDimension() {
            return this.provider.getDimension();
        }

        /**
         * @return a newly allocated array holding the element-wise negation of the
         *         wrapped provider's gradient
         */
        @Override
        public double[] getGradientLogDensity() {
            double[] gradient = this.provider.getGradientLogDensity();
            // Negate into a fresh array: the wrapped provider may hand out an internal
            // buffer, and flipping it in place would corrupt that provider's state.
            double[] negated = new double[gradient.length];
            for (int i = 0; i < gradient.length; ++i) {
                negated[i] = -gradient[i];
            }
            return negated;
        }
    }
}

