/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.CodeReviewResponse;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.ComplexNumber;
import gov.sandia.cognition.math.MultivariateStatisticsUtil;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.mtj.DenseMatrix;
import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ;
import gov.sandia.cognition.math.matrix.mtj.decomposition.CholeskyDecompositionMTJ;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.AbstractIncrementalEstimator;
import gov.sandia.cognition.statistics.AbstractSufficientStatistic;
import gov.sandia.cognition.statistics.ClosedFormComputableDistribution;
import gov.sandia.cognition.statistics.DistributionEstimator;
import gov.sandia.cognition.statistics.DistributionWeightedEstimator;
import gov.sandia.cognition.statistics.EstimableDistribution;
import gov.sandia.cognition.statistics.ProbabilityDensityFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;

/**
 * A multivariate Gaussian (normal) distribution over {@link Vector}s,
 * parameterized by a mean vector and a covariance matrix.
 * <p>
 * Either the covariance or its inverse may be supplied; setting one clears the
 * other, and the missing one is computed on demand by matrix inversion. The
 * log-determinant of the covariance and the log of the density's leading
 * coefficient are cached and invalidated whenever the covariance changes.
 */
@CodeReview(reviewer={"Jonathan McClain"}, date="2006-05-15", changesNeeded=true, comments={"A few minor changes needed.", "Comments indicated with a / / / comment in the first column."}, response={@CodeReviewResponse(respondent="Kevin R. Dixon", date="2006-05-15", moreChangesNeeded=false, comments={"Fixed."})})
@PublicationReference(author={"Wikipedia"}, title="Multivariate normal distribution", type=PublicationType.WebPage, year=2009, url="http://en.wikipedia.org/wiki/Multivariate_normal_distribution")
public class MultivariateGaussian
extends AbstractDistribution<Vector>
implements ClosedFormComputableDistribution<Vector>,
EstimableDistribution<Vector, MultivariateGaussian> {
    /** Default tolerance used when checking whether a covariance matrix is symmetric. */
    public static final double DEFAULT_COVARIANCE_SYMMETRY_TOLERANCE = 1.0E-5;
    /** Dimensionality used by the no-argument constructor. */
    public static final int DEFAULT_DIMENSIONALITY = 2;
    /** Precomputed natural logarithm of 2*pi. */
    public static final double LOG_TWO_PI = Math.log(Math.PI * 2);
    /** Mean of the distribution; {@link #setMean} rejects null. */
    private Vector mean;
    /** Covariance matrix; null when only the inverse has been set. */
    private Matrix covariance;
    /** Cached real part of log(det(covariance)); null until computed. */
    private Double logCovarianceDeterminant;
    /** Cached inverse of the covariance; null until computed or set. */
    private Matrix covarianceInverse;
    /** Cached log of the density's normalizing coefficient; null until computed. */
    private Double logLeadingCoefficient;

    /**
     * Creates a Gaussian of {@link #DEFAULT_DIMENSIONALITY} dimensions with a
     * newly created mean vector (presumably all zeros, per the default
     * VectorFactory) and identity covariance.
     */
    public MultivariateGaussian() {
        this(DEFAULT_DIMENSIONALITY);
    }

    /**
     * Creates a Gaussian with a newly created mean vector (presumably all
     * zeros, per the default VectorFactory) and identity covariance.
     *
     * @param dimensionality number of dimensions of the distribution
     */
    public MultivariateGaussian(int dimensionality) {
        this(VectorFactory.getDefault().createVector(dimensionality), MatrixFactory.getDefault().createIdentity(dimensionality, dimensionality));
    }

    /**
     * Creates a Gaussian with the given parameters. The arguments are stored
     * by reference (the covariance is copied only if it must be symmetrized).
     *
     * @param mean mean vector; must not be null
     * @param covariance covariance matrix; must not be null
     * @throws NullPointerException if either argument is null
     */
    public MultivariateGaussian(Vector mean, Matrix covariance) {
        this.setMean(mean);
        this.setCovariance(covariance);
    }

    /**
     * Copy constructor; the mean and covariance are cloned.
     *
     * @param other distribution to copy
     */
    public MultivariateGaussian(MultivariateGaussian other) {
        this((Vector)ObjectUtil.cloneSafe((CloneableSerializable)other.getMean()), (Matrix)ObjectUtil.cloneSafe((CloneableSerializable)other.getCovariance()));
    }

    /**
     * Creates a copy of this distribution with cloned mean and covariance.
     * Cached derived values are recomputed lazily on the copy.
     *
     * @return deep copy of this distribution
     */
    @Override
    public MultivariateGaussian clone() {
        MultivariateGaussian clone = (MultivariateGaussian)super.clone();
        clone.setMean((Vector)ObjectUtil.cloneSafe((CloneableSerializable)this.getMean()));
        clone.setCovariance((Matrix)ObjectUtil.cloneSafe((CloneableSerializable)this.getCovariance()));
        return clone;
    }

    /**
     * Gets a probability density function view of this distribution.
     *
     * @return new PDF sharing this distribution's mean and covariance objects
     */
    public PDF getProbabilityFunction() {
        return new PDF(this);
    }

    /**
     * Computes the squared Mahalanobis distance of the input from the mean,
     * (x - mean)' * C^-1 * (x - mean).
     *
     * @param input point at which to compute the distance
     * @return squared Mahalanobis distance
     */
    public double computeZSquared(Vector input) {
        Vector delta = (Vector)input.minus((Ring)this.mean);
        return delta.times(this.getCovarianceInverse()).dotProduct(delta);
    }

    /**
     * Multiplies this Gaussian density by another. The product of two Gaussian
     * densities is proportional to a Gaussian with inverse covariance
     * C1^-1 + C2^-1 and mean C * (C1^-1 * m1 + C2^-1 * m2); that normalized
     * Gaussian is returned.
     *
     * @param other Gaussian to multiply by
     * @return normalized Gaussian proportional to the product of the densities
     */
    public MultivariateGaussian times(MultivariateGaussian other) {
        Vector m1 = this.mean;
        Matrix c1inv = this.getCovarianceInverse();
        Vector m2 = other.getMean();
        Matrix c2inv = other.getCovarianceInverse();
        Matrix Cinv = (Matrix)c1inv.plus((Ring)c2inv);
        Matrix C = Cinv.inverse();
        Vector m = C.times((Vector)c1inv.times(m1).plus((Ring)c2inv.times(m2)));
        return new MultivariateGaussian(m, C);
    }

    /**
     * Convolves this Gaussian with another: the result has the summed means
     * and summed covariances (the distribution of X + Y for independent X, Y).
     *
     * @param other Gaussian to convolve with
     * @return convolved Gaussian
     */
    public MultivariateGaussian convolve(MultivariateGaussian other) {
        Vector meanHat = (Vector)this.mean.plus((Ring)other.getMean());
        Matrix covarianceHat = (Matrix)this.getCovariance().plus((Ring)other.getCovariance());
        return new MultivariateGaussian(meanHat, covarianceHat);
    }

    /**
     * Gets the dimensionality of the distribution.
     *
     * @return dimensionality of the mean, or 0 if the mean is null
     */
    public int getInputDimensionality() {
        return this.mean != null ? this.mean.getDimensionality() : 0;
    }

    @Override
    public Vector getMean() {
        return this.mean;
    }

    /**
     * Sets the mean (stored by reference).
     *
     * @param mean new mean; must not be null
     * @throws NullPointerException if {@code mean} is null
     */
    public void setMean(Vector mean) {
        if (mean == null) {
            throw new NullPointerException("Mean cannot be null.");
        }
        this.mean = mean;
    }

    /**
     * Gets the covariance, computing it from the inverse covariance if only
     * the inverse is currently known.
     *
     * @return covariance matrix
     */
    public Matrix getCovariance() {
        if (this.covariance == null) {
            this.covariance = this.covarianceInverse.inverse();
        }
        return this.covariance;
    }

    /**
     * Sets the covariance using {@link #DEFAULT_COVARIANCE_SYMMETRY_TOLERANCE}.
     *
     * @param covariance new covariance; must not be null
     */
    public void setCovariance(Matrix covariance) {
        this.setCovariance(covariance, DEFAULT_COVARIANCE_SYMMETRY_TOLERANCE);
    }

    /**
     * Sets the covariance. If the matrix is not symmetric within the given
     * tolerance, a symmetrized copy is stored instead (each off-diagonal pair
     * replaced by its average). Clears the cached inverse and derived values.
     *
     * @param covariance new covariance; must not be null
     * @param symmetryTolerance tolerance for the symmetry check
     * @throws NullPointerException if {@code covariance} is null
     */
    public void setCovariance(Matrix covariance, double symmetryTolerance) {
        if (covariance == null) {
            throw new NullPointerException("Covariance cannot be null.");
        }
        this.covariance = MultivariateGaussian.symmetrize(covariance, symmetryTolerance);
        this.covarianceInverse = null;
        this.logCovarianceDeterminant = null;
        this.logLeadingCoefficient = null;
    }

    /**
     * Gets the inverse covariance, computing it by inversion if needed.
     *
     * @return inverse of the covariance matrix
     */
    public Matrix getCovarianceInverse() {
        if (this.covarianceInverse == null) {
            this.covarianceInverse = this.covariance.inverse();
        }
        return this.covarianceInverse;
    }

    /**
     * Sets the inverse covariance using
     * {@link #DEFAULT_COVARIANCE_SYMMETRY_TOLERANCE}.
     *
     * @param covarianceInverse new inverse covariance; must not be null
     */
    public void setCovarianceInverse(Matrix covarianceInverse) {
        this.setCovarianceInverse(covarianceInverse, DEFAULT_COVARIANCE_SYMMETRY_TOLERANCE);
    }

    /**
     * Sets the inverse covariance. If the matrix is not symmetric within the
     * given tolerance, a symmetrized copy is stored instead. Clears the
     * covariance (recomputed lazily) and all cached derived values.
     *
     * @param covarianceInverse new inverse covariance; must not be null
     * @param symmetryTolerance tolerance for the symmetry check
     * @throws NullPointerException if {@code covarianceInverse} is null
     */
    public void setCovarianceInverse(Matrix covarianceInverse, double symmetryTolerance) {
        if (covarianceInverse == null) {
            throw new NullPointerException("Covariance inverse cannot be null.");
        }
        this.covarianceInverse = MultivariateGaussian.symmetrize(covarianceInverse, symmetryTolerance);
        this.covariance = null;
        this.logCovarianceDeterminant = null;
        this.logLeadingCoefficient = null;
    }

    /**
     * Returns the given matrix if it is symmetric within the tolerance;
     * otherwise returns a copy in which each off-diagonal pair is replaced by
     * its average. The input is never modified.
     */
    private static Matrix symmetrize(Matrix matrix, double symmetryTolerance) {
        if (matrix.isSymmetric(symmetryTolerance)) {
            return matrix;
        }
        Matrix result = matrix.clone();
        int N = result.getNumRows();
        for (int i = 1; i < N; ++i) {
            for (int j = 0; j < i; ++j) {
                double vij = result.getElement(i, j);
                double vji = result.getElement(j, i);
                if (vij != vji) {
                    double v = (vij + vji) / 2.0;
                    result.setElement(i, j, v);
                    result.setElement(j, i, v);
                }
            }
        }
        return result;
    }

    /**
     * Gets the (cached) real part of the natural log of the covariance
     * determinant.
     *
     * @return log(det(covariance))
     */
    public double getLogCovarianceDeterminant() {
        if (this.logCovarianceDeterminant == null) {
            // Use the getter: the covariance field is null when only the inverse was set.
            ComplexNumber logDeterminant = this.getCovariance().logDeterminant();
            this.logCovarianceDeterminant = logDeterminant.getRealPart();
        }
        return this.logCovarianceDeterminant;
    }

    /**
     * Gets the (cached) log of the density's leading coefficient,
     * -k/2 * log(2*pi) - 1/2 * log(det(C)), where k is the dimensionality.
     *
     * @return log of the normalizing coefficient
     */
    public double getLogLeadingCoefficient() {
        if (this.logLeadingCoefficient == null) {
            int k = this.getInputDimensionality();
            this.logLeadingCoefficient = -0.5 * (double)k * LOG_TWO_PI + -0.5 * this.getLogCovarianceDeterminant();
        }
        return this.logLeadingCoefficient;
    }

    /**
     * Two Gaussians are equal when their means and covariances are equal.
     */
    @Override
    public boolean equals(Object randomVariable) {
        boolean retval = false;
        if (randomVariable instanceof MultivariateGaussian) {
            MultivariateGaussian other = (MultivariateGaussian)randomVariable;
            retval = this.getMean().equals(other.getMean()) && this.getCovariance().equals(other.getCovariance());
        }
        return retval;
    }

    @Override
    public int hashCode() {
        int hash = 7;
        hash = 53 * hash + ObjectUtil.hashCodeSafe((Object)this.mean);
        hash = 53 * hash + ObjectUtil.hashCodeSafe((Object)this.getCovariance());
        return hash;
    }

    /**
     * Draws samples from this distribution.
     *
     * @param random source of randomness
     * @param numSamples number of samples to draw
     * @return list of samples
     */
    @Override
    public ArrayList<Vector> sample(Random random, int numSamples) {
        // The Cholesky factor R here is upper triangular with R' * R = C. The static
        // sample computes S * x, whose covariance is S * S', so S must be R' (not R;
        // R * R' differs from C whenever C is not diagonal).
        DenseMatrix upper = CholeskyDecompositionMTJ.create((DenseMatrix)DenseMatrixFactoryMTJ.INSTANCE.copyMatrix(this.getCovariance())).getR();
        Matrix lower = upper.transpose();
        return MultivariateGaussian.sample(this.mean, lower, random, numSamples);
    }

    /**
     * Draws samples {@code mean + S * x}, where x is standard normal.
     *
     * @param mean mean of the samples
     * @param covarianceSquareRoot factor S such that S * S' equals the desired covariance
     * @param random source of randomness
     * @param numDraws number of samples to draw
     * @return list of samples
     */
    public static ArrayList<Vector> sample(Vector mean, Matrix covarianceSquareRoot, Random random, int numDraws) {
        ArrayList<Vector> retval = new ArrayList<Vector>(numDraws);
        for (int n = 0; n < numDraws; ++n) {
            retval.add(MultivariateGaussian.sample(mean, covarianceSquareRoot, random));
        }
        return retval;
    }

    /**
     * Draws one sample {@code mean + S * x}, where x has one independent
     * standard normal entry per column of S.
     *
     * @param mean mean of the sample
     * @param covarianceSquareRoot factor S such that S * S' equals the desired covariance
     * @param random source of randomness
     * @return the sample
     */
    public static Vector sample(Vector mean, Matrix covarianceSquareRoot, Random random) {
        // x must match the column count of S for S * x; equals the row count for square S.
        int K = covarianceSquareRoot.getNumColumns();
        Vector x = VectorFactory.getDefault().createVector(K);
        for (int i = 0; i < K; ++i) {
            x.setElement(i, random.nextGaussian());
        }
        Vector sample = covarianceSquareRoot.times(x);
        sample.plusEquals((Ring)mean);
        return sample;
    }

    /**
     * Computes the distribution of A * X: mean A * m and covariance A * C * A'.
     *
     * @param premultiplyMatrix the matrix A
     * @return transformed Gaussian
     */
    public MultivariateGaussian scale(Matrix premultiplyMatrix) {
        Vector m = premultiplyMatrix.times(this.mean);
        Matrix C = premultiplyMatrix.times(this.getCovariance()).times(premultiplyMatrix.transpose());
        return new MultivariateGaussian(m, C);
    }

    /**
     * Adds two Gaussians: summed means and summed covariances.
     *
     * @param other Gaussian to add
     * @return summed Gaussian
     */
    public MultivariateGaussian plus(MultivariateGaussian other) {
        Vector m = (Vector)this.mean.plus((Ring)other.getMean());
        Matrix C = (Matrix)this.getCovariance().plus((Ring)other.getCovariance());
        return new MultivariateGaussian(m, C);
    }

    @Override
    public String toString() {
        return "Mean: " + this.getMean() + "\n" + "Covariance:\n" + this.getCovariance();
    }

    /**
     * Converts the parameters to a vector: the mean stacked on top of the
     * vectorized covariance.
     *
     * @return parameter vector
     */
    public Vector convertToVector() {
        return this.mean.stack(this.getCovariance().convertToVector());
    }

    /**
     * Sets the parameters from a vector in the layout produced by
     * {@link #convertToVector()}. The current dimensionality is kept.
     *
     * @param parameters mean followed by the vectorized covariance
     */
    public void convertFromVector(Vector parameters) {
        int N = this.getInputDimensionality();
        Vector newMean = parameters.subVector(0, N - 1);
        // Convert into a copy so a failure leaves this object unchanged and the
        // previously returned covariance matrix is not mutated in place.
        Matrix newCovariance = this.getCovariance().clone();
        newCovariance.convertFromVector(parameters.subVector(N, parameters.getDimensionality() - 1));
        this.setMean(newMean);
        this.setCovariance(newCovariance);
    }

    /**
     * Gets a maximum-likelihood estimator for this distribution.
     *
     * @return new estimator with the default covariance regularization
     */
    public MaximumLikelihoodEstimator getEstimator() {
        return new MaximumLikelihoodEstimator();
    }

    /**
     * Incremental estimator that accumulates a {@link SufficientStatistic}.
     */
    public static class IncrementalEstimator
    extends AbstractIncrementalEstimator<Vector, MultivariateGaussian, SufficientStatistic> {
        /** Default value added to the diagonal of the initial sum of squares. */
        public static final double DEFAULT_COVARIANCE = 1.0E-5;
        /** Value passed to each created sufficient statistic. */
        private double defaultCovariance;

        /** Creates an estimator with {@link #DEFAULT_COVARIANCE}. */
        public IncrementalEstimator() {
            this(DEFAULT_COVARIANCE);
        }

        /**
         * @param defaultCovariance diagonal regularization for new statistics
         */
        public IncrementalEstimator(double defaultCovariance) {
            this.setDefaultCovariance(defaultCovariance);
        }

        public double getDefaultCovariance() {
            return this.defaultCovariance;
        }

        public void setDefaultCovariance(double defaultCovariance) {
            this.defaultCovariance = defaultCovariance;
        }

        @Override
        public SufficientStatistic createInitialLearnedObject() {
            return new SufficientStatistic(this.getDefaultCovariance());
        }
    }

    /**
     * Running mean and sum of squared differences (Welford's algorithm) from
     * which a Gaussian can be created.
     */
    public static class SufficientStatistic
    extends AbstractSufficientStatistic<Vector, MultivariateGaussian> {
        /** Default value seeded onto the diagonal of the sum of squares. */
        public static final double DEFAULT_COVARIANCE = 1.0E-5;
        /** Running mean; null until the first update. */
        private Vector mean;
        /** Running sum of squared differences; null until the first update. */
        private Matrix sumSquaredDifferences;
        /** Diagonal seed for the sum of squares on the first update. */
        protected double defaultCovariance;

        /** Creates an empty statistic with {@link #DEFAULT_COVARIANCE}. */
        public SufficientStatistic() {
            this(DEFAULT_COVARIANCE);
        }

        /**
         * @param defaultCovariance diagonal seed for the sum of squares
         */
        public SufficientStatistic(double defaultCovariance) {
            this.clear();
            this.defaultCovariance = defaultCovariance;
        }

        /**
         * Deep-copies the statistic. The running mean and sum of squares are
         * cloned because {@link #update} mutates them in place; sharing them
         * would let updates to the copy corrupt the original.
         */
        @Override
        public SufficientStatistic clone() {
            SufficientStatistic clone = (SufficientStatistic)super.clone();
            clone.mean = (Vector)ObjectUtil.cloneSafe((CloneableSerializable)this.mean);
            clone.sumSquaredDifferences = (Matrix)ObjectUtil.cloneSafe((CloneableSerializable)this.sumSquaredDifferences);
            return clone;
        }

        /** Resets the count, mean, and sum of squares. */
        public void clear() {
            this.count = 0L;
            this.mean = null;
            this.sumSquaredDifferences = null;
        }

        /**
         * Adds one observation using Welford's online update.
         *
         * @param value observation to add
         */
        @Override
        public void update(Vector value) {
            ++this.count;
            int dim = value.getDimensionality();
            if (this.mean == null) {
                this.mean = VectorFactory.getDefault().createVector(dim);
            }
            // delta uses the old mean, delta2 the updated mean.
            Vector delta = (Vector)value.minus((Ring)this.mean);
            this.mean.plusEquals(delta.scale(1.0 / (double)this.count));
            if (this.sumSquaredDifferences == null) {
                this.sumSquaredDifferences = MatrixFactory.getDefault().createIdentity(dim, dim);
                this.sumSquaredDifferences.scaleEquals(this.getDefaultCovariance());
            }
            Vector delta2 = (Vector)value.minus((Ring)this.mean);
            this.sumSquaredDifferences.plusEquals((Ring)delta.outerProduct(delta2));
        }

        /**
         * Creates a PDF from the current statistics. The mean is copied so
         * later updates do not change the returned distribution.
         *
         * @return new PDF
         * @throws NullPointerException if no observations have been added
         */
        public PDF create() {
            return new PDF((Vector)ObjectUtil.cloneSafe((CloneableSerializable)this.getMean()), this.getCovariance());
        }

        /**
         * Sets the given distribution's parameters from the current statistics.
         * The mean is copied so later updates do not change the distribution.
         *
         * @param distribution distribution to modify
         */
        @Override
        public void create(MultivariateGaussian distribution) {
            distribution.setMean((Vector)ObjectUtil.cloneSafe((CloneableSerializable)this.getMean()));
            distribution.setCovariance(this.getCovariance());
        }

        public double getDefaultCovariance() {
            return this.defaultCovariance;
        }

        public void setDefaultCovariance(double defaultCovariance) {
            this.defaultCovariance = defaultCovariance;
        }

        /** @return the internal running mean (not a copy), or null if empty */
        public Vector getMean() {
            return this.mean;
        }

        /** @return the internal sum of squared differences, or null if empty */
        public Matrix getSumSquaredDifferences() {
            return this.sumSquaredDifferences;
        }

        /**
         * Gets the sample covariance: the sum of squares divided by
         * (count - 1). For a single observation, a copy of the sum of squares
         * (the diagonal seed) is returned.
         *
         * @return new covariance matrix, or null if no observations were added
         */
        public Matrix getCovariance() {
            if (this.count <= 0L) {
                return null;
            }
            if (this.count == 1L) {
                return this.sumSquaredDifferences.clone();
            }
            return (Matrix)this.sumSquaredDifferences.scale(1.0 / ((double)this.count - 1.0));
        }
    }

    /**
     * Weighted maximum-likelihood estimator; adds {@code defaultCovariance}
     * to the diagonal of the estimated covariance.
     */
    public static class WeightedMaximumLikelihoodEstimator
    extends AbstractCloneableSerializable
    implements DistributionWeightedEstimator<Vector, PDF> {
        /** Default diagonal regularization. */
        public static final double DEFAULT_COVARIANCE = 1.0E-5;
        /** Value added to each diagonal element of the estimated covariance. */
        private double defaultCovariance;

        /** Creates an estimator with {@link #DEFAULT_COVARIANCE}. */
        public WeightedMaximumLikelihoodEstimator() {
            this(DEFAULT_COVARIANCE);
        }

        /**
         * @param defaultCovariance value added to the covariance diagonal
         */
        public WeightedMaximumLikelihoodEstimator(double defaultCovariance) {
            this.defaultCovariance = defaultCovariance;
        }

        @Override
        public PDF learn(Collection<? extends WeightedValue<? extends Vector>> data) {
            return WeightedMaximumLikelihoodEstimator.learn(data, this.defaultCovariance);
        }

        /**
         * Estimates a Gaussian from weighted data.
         *
         * @param data weighted samples; at least two are required
         * @param defaultCovariance value added to the covariance diagonal
         * @return estimated PDF
         * @throws IllegalArgumentException if fewer than two samples are given
         */
        public static PDF learn(Collection<? extends WeightedValue<? extends Vector>> data, double defaultCovariance) {
            int N = data.size();
            if (N <= 1) {
                throw new IllegalArgumentException("The number of samples must be greater than 1.");
            }
            Pair mle = MultivariateStatisticsUtil.computeWeightedMeanAndCovariance(data);
            Vector mean = (Vector)mle.getFirst();
            Matrix covariance = (Matrix)mle.getSecond();
            if (defaultCovariance != 0.0) {
                int M = covariance.getNumRows();
                for (int i = 0; i < M; ++i) {
                    double v = covariance.getElement(i, i);
                    covariance.setElement(i, i, v + defaultCovariance);
                }
            }
            return new PDF(mean, covariance);
        }
    }

    /**
     * Maximum-likelihood estimator; adds {@code defaultCovariance} to the
     * diagonal of the estimated covariance.
     */
    public static class MaximumLikelihoodEstimator
    extends AbstractCloneableSerializable
    implements DistributionEstimator<Vector, PDF> {
        /** Default diagonal regularization. */
        public static final double DEFAULT_COVARIANCE = 1.0E-5;
        /** Value added to each diagonal element of the estimated covariance. */
        private double defaultCovariance;

        /** Creates an estimator with {@link #DEFAULT_COVARIANCE}. */
        public MaximumLikelihoodEstimator() {
            this(DEFAULT_COVARIANCE);
        }

        /**
         * @param defaultCovariance value added to the covariance diagonal
         */
        public MaximumLikelihoodEstimator(double defaultCovariance) {
            this.defaultCovariance = defaultCovariance;
        }

        /**
         * Estimates a Gaussian from data.
         *
         * @param data samples; at least two are required
         * @param defaultCovariance value added to the covariance diagonal
         * @return estimated PDF
         * @throws IllegalArgumentException if fewer than two samples are given
         */
        public static PDF learn(Collection<? extends Vector> data, double defaultCovariance) {
            int N = data.size();
            if (N <= 1) {
                throw new IllegalArgumentException("Need at least 2 data points to compute covariance");
            }
            Pair mle = MultivariateStatisticsUtil.computeMeanAndCovariance(data);
            Vector mean = (Vector)mle.getFirst();
            Matrix covariance = (Matrix)mle.getSecond();
            if (defaultCovariance != 0.0) {
                int M = mean.getDimensionality();
                for (int i = 0; i < M; ++i) {
                    double v = covariance.getElement(i, i);
                    covariance.setElement(i, i, v + defaultCovariance);
                }
            }
            return new PDF(mean, covariance);
        }

        @Override
        public PDF learn(Collection<? extends Vector> data) {
            return MaximumLikelihoodEstimator.learn(data, this.defaultCovariance);
        }
    }

    /**
     * Probability density function of a multivariate Gaussian.
     */
    public static class PDF
    extends MultivariateGaussian
    implements ProbabilityDensityFunction<Vector>,
    VectorInputEvaluator<Vector, Double> {
        /** Creates a default-dimensional PDF with identity covariance. */
        public PDF() {
        }

        /**
         * @param dimensionality number of dimensions
         */
        public PDF(int dimensionality) {
            super(dimensionality);
        }

        /**
         * @param mean mean vector
         * @param covariance covariance matrix
         */
        public PDF(Vector mean, Matrix covariance) {
            super(mean, covariance);
        }

        /**
         * @param other distribution to copy (mean and covariance are cloned)
         */
        public PDF(MultivariateGaussian other) {
            super(other);
        }

        /**
         * Evaluates the density at the input.
         *
         * @param input point at which to evaluate
         * @return density value, exp(logEvaluate(input))
         */
        @Override
        public Double evaluate(Vector input) {
            return Math.exp(this.logEvaluate(input));
        }

        /**
         * Evaluates the log density: log leading coefficient minus half the
         * squared Mahalanobis distance.
         */
        @Override
        public double logEvaluate(Vector input) {
            double zsquared = this.computeZSquared(input);
            return this.getLogLeadingCoefficient() - 0.5 * zsquared;
        }

        @Override
        public PDF getProbabilityFunction() {
            return this;
        }
    }
}

