/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.ClosedFormComputableDistribution;
import gov.sandia.cognition.statistics.ProbabilityDensityFunction;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.distribution.LinearMixtureModel;
import gov.sandia.cognition.util.ObjectUtil;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;

@PublicationReference(author={"Wikipedia"}, title="Mixture Model", type=PublicationType.WebPage, year=2009, url="http://en.wikipedia.org/wiki/Mixture_model")
public class MultivariateMixtureDensityModel<DistributionType extends ClosedFormComputableDistribution<Vector>>
extends LinearMixtureModel<Vector, DistributionType>
implements ClosedFormComputableDistribution<Vector> {
    public MultivariateMixtureDensityModel(DistributionType ... distributions) {
        this((Collection<DistributionType>)Arrays.asList(distributions));
    }

    public MultivariateMixtureDensityModel(Collection<? extends DistributionType> distributions) {
        this(distributions, (double[])null);
    }

    public MultivariateMixtureDensityModel(Collection<? extends DistributionType> distributions, double[] priorWeights) {
        super(distributions, priorWeights);
    }

    public MultivariateMixtureDensityModel(MultivariateMixtureDensityModel<? extends DistributionType> other) {
        this(ObjectUtil.cloneSmartElementsAsArrayList(other.getDistributions()), (double[])ObjectUtil.deepCopy((Serializable)other.getPriorWeights()));
    }

    @Override
    public MultivariateMixtureDensityModel<DistributionType> clone() {
        return (MultivariateMixtureDensityModel)super.clone();
    }

    @Override
    public Vector getMean() {
        RingAccumulator mean = new RingAccumulator();
        int K = this.getDistributionCount();
        for (int k = 0; k < K; ++k) {
            mean.accumulate(((Vector)((ClosedFormComputableDistribution)this.getDistributions().get(k)).getMean()).scale(this.getPriorWeights()[k]));
        }
        return (Vector)((Vector)mean.getSum()).scale(1.0 / this.getPriorWeightSum());
    }

    public Vector convertToVector() {
        return VectorFactory.getDefault().copyArray(this.getPriorWeights());
    }

    public void convertFromVector(Vector parameters) {
        parameters.assertDimensionalityEquals(this.getDistributionCount());
        for (int k = 0; k < parameters.getDimensionality(); ++k) {
            this.priorWeights[k] = parameters.getElement(k);
        }
    }

    @Override
    public PDF<DistributionType> getProbabilityFunction() {
        return new PDF(this);
    }

    public static class PDF<DistributionType extends ClosedFormComputableDistribution<Vector>>
    extends MultivariateMixtureDensityModel<DistributionType>
    implements ProbabilityDensityFunction<Vector> {
        public PDF(DistributionType ... distributions) {
            super(distributions);
        }

        public PDF(Collection<? extends DistributionType> distributions) {
            super(distributions);
        }

        public PDF(Collection<? extends DistributionType> distributions, double[] priorWeights) {
            super(distributions, priorWeights);
        }

        public PDF(MultivariateMixtureDensityModel<? extends DistributionType> other) {
            super(other);
        }

        @Override
        public PDF<DistributionType> getProbabilityFunction() {
            return this;
        }

        @Override
        public double logEvaluate(Vector input) {
            return Math.log(this.evaluate(input));
        }

        public Double evaluate(Vector input) {
            double sum = 0.0;
            int K = this.getDistributionCount();
            for (int k = 0; k < K; ++k) {
                ProbabilityFunction pdf = ((ClosedFormComputableDistribution)this.getDistributions().get(k)).getProbabilityFunction();
                sum += (Double)pdf.evaluate(input) * this.priorWeights[k];
            }
            return sum / this.getPriorWeightSum();
        }

        public double[] computeRandomVariableProbabilities(Vector input) {
            int k;
            int K = this.getDistributionCount();
            double[] likelihoods = this.computeRandomVariableLikelihoods(input);
            double sum = 0.0;
            for (k = 0; k < K; ++k) {
                sum += likelihoods[k];
            }
            if (sum <= 0.0) {
                Arrays.fill(likelihoods, 1.0 / (double)K);
            }
            sum = 0.0;
            for (k = 0; k < K; ++k) {
                int n = k;
                likelihoods[n] = likelihoods[n] * this.priorWeights[k];
                sum += likelihoods[k];
            }
            if (sum <= 0.0) {
                Arrays.fill(likelihoods, 1.0 / (double)K);
                sum = 1.0;
            }
            k = 0;
            while (k < K) {
                int n = k++;
                likelihoods[n] = likelihoods[n] / sum;
            }
            return likelihoods;
        }

        public double[] computeRandomVariableLikelihoods(Vector input) {
            int K = this.getDistributionCount();
            double[] likelihoods = new double[K];
            for (int k = 0; k < K; ++k) {
                ProbabilityFunction pdf = ((ClosedFormComputableDistribution)this.getDistributions().get(k)).getProbabilityFunction();
                likelihoods[k] = (Double)pdf.evaluate(input);
            }
            return likelihoods;
        }

        public int getMostLikelyRandomVariable(Vector input) {
            double[] probabilities = this.computeRandomVariableProbabilities(input);
            int bestIndex = 0;
            double bestProbability = probabilities[0];
            for (int i = 1; i < probabilities.length; ++i) {
                double prob = probabilities[i];
                if (!(bestProbability < prob)) continue;
                bestProbability = prob;
                bestIndex = i;
            }
            return bestIndex;
        }
    }
}

