/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.hmm;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.hmm.MarkovChain;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.Distribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.distribution.DirichletDistribution;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

@PublicationReference(author={"Lawrence R. Rabiner"}, title="A tutorial on hidden Markov models and selected applications in speech recognition", type=PublicationType.Journal, year=1989, publication="Proceedings of the IEEE", pages={257, 286}, url="http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf", notes={"Rabiner's transition matrix is transposed from mine."})
public class HiddenMarkovModel<ObservationType>
extends MarkovChain
implements Distribution<ObservationType> {
    /**
     * Emission distributions of this model. The three-argument constructor
     * requires this collection to contain exactly as many entries as the
     * model has states; the no-argument and single-int constructors leave it
     * unassigned until {@link #setEmissionFunctions(Collection)} is called.
     */
    protected Collection<? extends ComputableDistribution<ObservationType>> emissionFunctions;

    /**
     * Creates a model using the superclass's no-argument constructor.
     * The emission functions are not assigned here; they must be supplied
     * through {@link #setEmissionFunctions(Collection)} before calling any
     * method that reads them.
     */
    public HiddenMarkovModel() {
    }

    /**
     * Creates a model by forwarding the state count to the superclass
     * constructor. The emission functions are not assigned here; they must be
     * supplied through {@link #setEmissionFunctions(Collection)} before
     * calling any method that reads them.
     *
     * @param numStates value handed to the superclass constructor
     */
    public HiddenMarkovModel(int numStates) {
        super(numStates);
    }

    /**
     * Creates a model from explicit parameters.
     *
     * @param initialProbability forwarded to the superclass constructor
     * @param transitionProbability forwarded to the superclass constructor
     * @param emissionFunctions emission distributions; the collection must
     *        hold exactly {@link #getNumStates()} entries
     * @throws IllegalArgumentException if the number of emission functions
     *         differs from the number of states
     */
    public HiddenMarkovModel(Vector initialProbability, Matrix transitionProbability, Collection<? extends ComputableDistribution<ObservationType>> emissionFunctions) {
        super(initialProbability, transitionProbability);
        final boolean onePerState = this.getNumStates() == emissionFunctions.size();
        if (!onePerState) {
            throw new IllegalArgumentException("Number of PDFs must be equal to number of states!");
        }
        this.setEmissionFunctions(emissionFunctions);
    }

    /**
     * Creates a randomly initialized model whose emission distribution is
     * learned from the given data.
     * <p>
     * Every observation is wrapped with a weight of 1.0, the learner is run
     * once on the weighted data, and the result is handed to
     * {@link #createRandom(int, ComputableDistribution, Random)}.
     *
     * @param <ObservationType> type of the observations
     * @param numStates number of states of the resulting model
     * @param learner learner that produces a distribution from weighted
     *        observations
     * @param data observations to learn the distribution from
     * @param random random number generator used for initialization
     * @return the newly created model
     */
    public static <ObservationType> HiddenMarkovModel<ObservationType> createRandom(int numStates, BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>, ? extends ComputableDistribution<ObservationType>> learner, Collection<? extends ObservationType> data, Random random) {
        // The element type must be parameterized: a list of raw
        // DefaultWeightedValue is not a
        // Collection<? extends WeightedValue<? extends ObservationType>>.
        ArrayList<DefaultWeightedValue<ObservationType>> weightedData = new ArrayList<DefaultWeightedValue<ObservationType>>(data.size());
        for (ObservationType observation : data) {
            weightedData.add(new DefaultWeightedValue<ObservationType>(observation, 1.0));
        }
        ComputableDistribution<ObservationType> distribution = learner.learn(weightedData);
        return HiddenMarkovModel.createRandom(numStates, distribution, random);
    }

    /**
     * Creates a randomly initialized model in which every state uses the
     * probability function of the given distribution.
     * <p>
     * Note that the probability function is obtained once and the same
     * instance is repeated for every state (no per-state copies are made).
     *
     * @param <ObservationType> type of the observations
     * @param numStates number of states of the resulting model
     * @param distribution distribution whose probability function is used
     *        as the emission function of each state
     * @param random random number generator used for initialization
     * @return the newly created model
     */
    public static <ObservationType> HiddenMarkovModel<ObservationType> createRandom(int numStates, ComputableDistribution<ObservationType> distribution, Random random) {
        final ProbabilityFunction<ObservationType> shared = distribution.getProbabilityFunction();
        final List<ProbabilityFunction<ObservationType>> perState = Collections.nCopies(numStates, shared);
        return HiddenMarkovModel.createRandom(perState, random);
    }

    /**
     * Creates a model with the given emission functions and randomly drawn
     * initial and transition probabilities.
     * <p>
     * A {@link DirichletDistribution} constructed with the number of states
     * is sampled once for the initial-probability vector and then once per
     * state; the i-th of those samples becomes column i of the transition
     * matrix.
     *
     * @param <ObservationType> type of the observations
     * @param distributions emission functions, one per state; the size of
     *        this collection determines the number of states
     * @param random random number generator used for sampling
     * @return the newly created model
     */
    public static <ObservationType> HiddenMarkovModel<ObservationType> createRandom(Collection<? extends ProbabilityFunction<ObservationType>> distributions, Random random) {
        final int stateCount = distributions.size();
        final DirichletDistribution prior = new DirichletDistribution(stateCount);

        // Keep the draw order (initial vector first, then the columns) so a
        // seeded generator yields the same model as before.
        final Vector startProbabilities = (Vector)prior.sample(random);
        final ArrayList<Vector> columns = prior.sample(random, stateCount);

        final Matrix transitions = MatrixFactory.getDefault().createMatrix(stateCount, stateCount);
        int column = 0;
        while (column < stateCount) {
            transitions.setColumn(column, columns.get(column));
            column++;
        }
        return new HiddenMarkovModel<ObservationType>(startProbabilities, transitions, distributions);
    }

    /**
     * Copies this model. The superclass clone is taken first, then the
     * emission functions are replaced by the result of
     * {@code ObjectUtil.cloneSmartElementsAsArrayList} applied to this
     * model's emission functions.
     *
     * @return the copy
     */
    @Override
    public HiddenMarkovModel<ObservationType> clone() {
        final HiddenMarkovModel copy = (HiddenMarkovModel)super.clone();
        final Collection<? extends ComputableDistribution<ObservationType>> originals = this.getEmissionFunctions();
        copy.setEmissionFunctions(ObjectUtil.cloneSmartElementsAsArrayList(originals));
        return copy;
    }

    /**
     * Computes the log-likelihood of an observation sequence.
     * <p>
     * Starting from a copy of the initial probabilities, each step (after the
     * first) multiplies the running vector by the transition matrix, scales
     * it element-wise by the per-state likelihoods of the observation, and
     * renormalizes it to unit 1-norm. The logarithms of the normalizers are
     * summed to give the result.
     *
     * @param observations observation sequence, in order
     * @return sum of the logs of the per-step normalizers; 0.0 for an empty
     *         sequence
     */
    public double computeObservationLogLikelihood(Collection<? extends ObservationType> observations) {
        final int stateCount = this.getNumStates();
        final Vector likelihoods = VectorFactory.getDefault().createVector(stateCount);
        Vector belief = this.getInitialProbability().clone();
        final Matrix transition = this.getTransitionProbability();

        double total = 0.0;
        boolean atFirstObservation = true;
        for (ObservationType observation : observations) {
            if (atFirstObservation) {
                // No transition is applied before the first observation.
                atFirstObservation = false;
            } else {
                belief = transition.times(belief);
            }
            this.computeObservationLikelihoods(observation, likelihoods);
            belief.dotTimesEquals((Ring)likelihoods);

            final double normalizer = belief.norm1();
            belief.scaleEquals(1.0 / normalizer);
            total += Math.log(normalizer);
        }
        return total;
    }

    /**
     * Sums {@link #computeObservationLogLikelihood(Collection)} over several
     * observation sequences.
     *
     * @param sequences observation sequences to evaluate
     * @return sum of the per-sequence log-likelihoods; 0.0 when there are no
     *         sequences
     */
    protected double computeMultipleObservationLogLikelihood(Collection<? extends Collection<? extends ObservationType>> sequences) {
        double total = 0.0;
        // The element type is Collection<? extends ObservationType>; declaring
        // the loop variable as Collection<ObservationType> is an incompatible
        // type because generics are invariant.
        for (Collection<? extends ObservationType> sequence : sequences) {
            total += this.computeObservationLogLikelihood(sequence);
        }
        return total;
    }

    /**
     * Computes the joint log-likelihood of an observation sequence and a
     * given state sequence.
     * <p>
     * For each step the log emission value of the observation under the
     * step's state is added to the log of either the initial probability of
     * that state (first step) or the transition-matrix element at
     * (row = current state, column = previous state).
     *
     * @param observations observation sequence, in order
     * @param states state index for each observation, in the same order
     * @return the accumulated log-likelihood; 0.0 for empty sequences
     * @throws IllegalArgumentException if the two collections differ in size
     */
    public double computeObservationLogLikelihood(Collection<? extends ObservationType> observations, Collection<Integer> states) {
        if (observations.size() != states.size()) {
            throw new IllegalArgumentException("Observations and states must be the same size");
        }
        final Iterator<Integer> stateIterator = states.iterator();

        // Resolve every state's probability function once, up front.
        final ArrayList<ProbabilityFunction<ObservationType>> emissions = new ArrayList<ProbabilityFunction<ObservationType>>(this.getNumStates());
        for (ComputableDistribution<ObservationType> emission : this.getEmissionFunctions()) {
            emissions.add(emission.getProbabilityFunction());
        }

        double total = 0.0;
        int previous = -1;
        for (ObservationType observation : observations) {
            final int current = stateIterator.next();
            final double emissionLog = emissions.get(current).logEvaluate(observation);
            final double pathLog;
            if (previous < 0) {
                pathLog = Math.log(this.initialProbability.getElement(current));
            } else {
                pathLog = Math.log(this.transitionProbability.getElement(current, previous));
            }
            previous = current;
            total += emissionLog + pathLog;
        }
        return total;
    }

    /**
     * Draws a single observation by requesting a one-element sample sequence
     * and returning its first element.
     *
     * @param random random number generator
     * @return the sampled observation
     */
    @Override
    public ObservationType sample(Random random) {
        final ArrayList<ObservationType> single = this.sample(random, 1);
        return (ObservationType)CollectionUtil.getFirst(single);
    }

    /**
     * Draws a sequence of observations from the model.
     * <p>
     * The first state is selected from the initial probabilities; each later
     * state is selected from the column of the transition matrix that belongs
     * to the previous state. After selecting a state, one observation is
     * sampled from that state's emission function.
     * <p>
     * State selection subtracts successive probabilities from a uniform draw
     * until the remainder is no longer positive. The selected index is kept
     * inside the valid range: a draw of exactly 0.0 selects state 0, and a
     * probability vector whose sum falls short of the draw (round-off) selects
     * the last state. Previously both cases produced an out-of-range index.
     *
     * @param random random number generator
     * @param numSamples number of observations to draw
     * @return the sampled observations, in sequence order
     */
    @Override
    public ArrayList<ObservationType> sample(Random random, int numSamples) {
        ArrayList<ObservationType> samples = new ArrayList<ObservationType>(numSamples);
        Vector p = this.getInitialProbability();
        for (int n = 0; n < numSamples; ++n) {
            final int lastState = p.getDimensionality() - 1;

            // One uniform draw per step, exactly as before, so seeded
            // generators reproduce the same sequences.
            double value = random.nextDouble() - p.getElement(0);
            int state = 0;
            while (value > 0.0 && state < lastState) {
                ++state;
                value -= p.getElement(state);
            }

            ComputableDistribution<ObservationType> emission = CollectionUtil.getElement(this.getEmissionFunctions(), state);
            samples.add(emission.sample(random));
            p = this.getTransitionProbability().getColumn(state);
        }
        return samples;
    }

    /**
     * Returns the emission functions.
     *
     * @return the stored collection itself (not a copy); {@code null} if it
     *         has not been assigned
     */
    public Collection<? extends ComputableDistribution<ObservationType>> getEmissionFunctions() {
        return this.emissionFunctions;
    }

    /**
     * Replaces the emission functions. The reference is stored as given; no
     * copy is made and the size is not checked against the number of states.
     *
     * @param emissionFunctions emission functions to store
     */
    public void setEmissionFunctions(Collection<? extends ComputableDistribution<ObservationType>> emissionFunctions) {
        this.emissionFunctions = emissionFunctions;
    }

    /**
     * Advances the forward recursion by one step: multiplies the transition
     * matrix by {@code alpha}, then scales the result element-wise by
     * {@code b}.
     *
     * @param alpha forward vector of the previous step
     * @param b per-state observation likelihoods of the current step
     * @param normalize when true, the new vector is scaled to unit 1-norm
     * @return the new forward vector paired with the scale that was applied
     *         to it (the reciprocal of its 1-norm, or 1.0 when not
     *         normalizing)
     */
    protected WeightedValue<Vector> computeForwardProbabilities(Vector alpha, Vector b, boolean normalize) {
        final Vector next = this.getTransitionProbability().times(alpha);
        next.dotTimesEquals((Ring)b);

        double scale = 1.0;
        if (normalize) {
            scale = 1.0 / next.norm1();
            next.scaleEquals(scale);
        }
        return new DefaultWeightedValue((Object)next, scale);
    }

    /**
     * Computes the forward vectors for a whole sequence of per-state
     * observation likelihoods.
     * <p>
     * The first vector is the element-wise product of {@code b.get(0)} and
     * the initial probabilities; the remaining ones come from
     * {@link #computeForwardProbabilities(Vector, Vector, boolean)}.
     *
     * @param b per-state observation likelihoods, one vector per step
     * @param normalize when true, every forward vector is scaled to unit
     *        1-norm and paired with the scale that was applied
     * @return one weighted forward vector per step, in order; an empty list
     *         when {@code b} is empty
     */
    protected ArrayList<WeightedValue<Vector>> computeForwardProbabilities(ArrayList<Vector> b, boolean normalize) {
        final int N = b.size();
        ArrayList<WeightedValue<Vector>> weightedAlphas = new ArrayList<WeightedValue<Vector>>(N);
        if (N == 0) {
            // Nothing observed: there is no first vector to seed the recursion.
            return weightedAlphas;
        }

        Vector alpha = (Vector)b.get(0).dotTimes((Ring)this.getInitialProbability());
        double weight = 1.0;
        if (normalize) {
            weight = 1.0 / alpha.norm1();
            alpha.scaleEquals(weight);
        }

        // Use the same concrete type that the single-step overload returns.
        WeightedValue<Vector> weightedAlpha = new DefaultWeightedValue<Vector>(alpha, weight);
        weightedAlphas.add(weightedAlpha);
        for (int n = 1; n < N; ++n) {
            weightedAlpha = this.computeForwardProbabilities((Vector)weightedAlpha.getValue(), b.get(n), normalize);
            weightedAlphas.add(weightedAlpha);
        }
        return weightedAlphas;
    }

    /**
     * Evaluates every emission function at the given observation.
     *
     * @param observation observation to evaluate
     * @return a new vector with one entry per emission function, filled in
     *         the iteration order of the emission functions
     */
    public Vector computeObservationLikelihoods(ObservationType observation) {
        final int functionCount = this.getEmissionFunctions().size();
        final Vector likelihoods = VectorFactory.getDefault().createVector(functionCount);
        this.computeObservationLikelihoods(observation, likelihoods);
        return likelihoods;
    }

    /**
     * Evaluates every emission function at the given observation and writes
     * the values into {@code b}, one element per function in iteration order.
     *
     * @param observation observation to evaluate
     * @param b destination vector; elements 0..(number of functions - 1)
     *        are overwritten
     */
    protected void computeObservationLikelihoods(ObservationType observation, Vector b) {
        final Iterator<? extends ComputableDistribution<ObservationType>> functions = this.getEmissionFunctions().iterator();
        for (int state = 0; functions.hasNext(); ++state) {
            final ComputableDistribution<ObservationType> f = functions.next();
            final double likelihood = ((Double)f.getProbabilityFunction().evaluate(observation)).doubleValue();
            b.setElement(state, likelihood);
        }
    }

    /**
     * Evaluates the emission functions for every observation of a sequence.
     *
     * @param observations observations to evaluate
     * @return one likelihood vector per observation, in iteration order
     */
    protected ArrayList<Vector> computeObservationLikelihoods(Collection<? extends ObservationType> observations) {
        final ArrayList<Vector> perObservation = new ArrayList<Vector>(observations.size());
        final Iterator<? extends ObservationType> remaining = observations.iterator();
        while (remaining.hasNext()) {
            final ObservationType observation = remaining.next();
            perObservation.add(this.computeObservationLikelihoods(observation));
        }
        return perObservation;
    }

    /**
     * Performs one step of the backward recursion: forms the element-wise
     * product of {@code b} and {@code beta}, multiplies that vector by the
     * transition matrix, and scales the result by {@code weight}.
     *
     * @param beta backward vector of the following step
     * @param b per-state observation likelihoods of the following step
     * @param weight scale applied to the result (skipped when exactly 1.0)
     * @return the new backward vector paired with {@code weight}
     */
    protected WeightedValue<Vector> computeBackwardProbabilities(Vector beta, Vector b, double weight) {
        final Vector weighted = (Vector)b.dotTimes((Ring)beta);
        final Vector previous = weighted.times(this.getTransitionProbability());
        final boolean needsScaling = weight != 1.0;
        if (needsScaling) {
            previous.scaleEquals(weight);
        }
        return new DefaultWeightedValue((Object)previous, weight);
    }

    /**
     * Computes the backward vectors for a whole sequence, reusing the
     * weights stored with the forward vectors as the per-step scales.
     * <p>
     * The last backward vector is an all-ones vector scaled by the weight of
     * the last forward vector; earlier ones are produced from back to front
     * by {@link #computeBackwardProbabilities(Vector, Vector, double)}.
     *
     * @param b per-state observation likelihoods, one vector per step
     * @param alphas weighted forward vectors, one per step; only their
     *        weights are read
     * @return one weighted backward vector per step, in sequence order; an
     *         empty list when {@code b} is empty
     */
    protected ArrayList<WeightedValue<Vector>> computeBackwardProbabilities(ArrayList<Vector> b, ArrayList<WeightedValue<Vector>> alphas) {
        final int N = b.size();
        if (N == 0) {
            // Nothing observed: avoid indexing alphas at N - 1 == -1.
            return new ArrayList<WeightedValue<Vector>>(0);
        }
        final int k = this.getInitialProbability().getDimensionality();

        // Pre-size with nulls because the list is filled from the back.
        ArrayList<WeightedValue<Vector>> weightedBetas = new ArrayList<WeightedValue<Vector>>(Collections.<WeightedValue<Vector>>nCopies(N, null));

        Vector beta = VectorFactory.getDefault().createVector(k, 1.0);
        double weight = alphas.get(N - 1).getWeight();
        if (weight != 1.0) {
            beta.scaleEquals(weight);
        }

        // Use the same concrete type that the single-step overload returns.
        WeightedValue<Vector> weightedBeta = new DefaultWeightedValue<Vector>(beta, weight);
        weightedBetas.set(N - 1, weightedBeta);
        for (int n = N - 2; n >= 0; --n) {
            weight = alphas.get(n).getWeight();
            weightedBeta = this.computeBackwardProbabilities((Vector)weightedBeta.getValue(), b.get(n + 1), weight);
            weightedBetas.set(n, weightedBeta);
        }
        return weightedBetas;
    }

    /**
     * Combines a forward and a backward vector element-wise and rescales the
     * product so that its 1-norm equals {@code scaleFactor}.
     *
     * @param alpha forward vector
     * @param beta backward vector
     * @param scaleFactor 1-norm the result is scaled to
     * @return the rescaled element-wise product, as a new vector
     */
    protected static Vector computeStateObservationLikelihood(Vector alpha, Vector beta, double scaleFactor) {
        final Vector gamma = (Vector)alpha.dotTimes((Ring)beta);
        final double norm = gamma.norm1();
        gamma.scaleEquals(scaleFactor / norm);
        return gamma;
    }

    /**
     * Applies
     * {@link #computeStateObservationLikelihood(Vector, Vector, double)} to
     * each pair of forward and backward vectors that share an index.
     *
     * @param alphas weighted forward vectors; its size determines the number
     *        of steps processed
     * @param betas weighted backward vectors, indexed like {@code alphas}
     * @param scaleFactor 1-norm every resulting vector is scaled to
     * @return one combined vector per step, in order
     */
    protected ArrayList<Vector> computeStateObservationLikelihood(ArrayList<WeightedValue<Vector>> alphas, ArrayList<WeightedValue<Vector>> betas, double scaleFactor) {
        final ArrayList<Vector> gammas = new ArrayList<Vector>(alphas.size());
        int step = 0;
        for (WeightedValue<Vector> weightedAlpha : alphas) {
            final Vector forward = (Vector)weightedAlpha.getValue();
            final Vector backward = (Vector)betas.get(step).getValue();
            gammas.add(HiddenMarkovModel.computeStateObservationLikelihood(forward, backward, scaleFactor));
            ++step;
        }
        return gammas;
    }

    /**
     * Builds the outer product of (the element-wise product of
     * {@code bnp1} and {@code betanp1}) with {@code alphan}.
     *
     * @param alphan forward vector at step n
     * @param betanp1 backward vector at step n + 1
     * @param bnp1 per-state observation likelihoods at step n + 1
     * @return the outer-product matrix
     */
    protected static Matrix computeTransitions(Vector alphan, Vector betanp1, Vector bnp1) {
        final Vector weightedBeta = (Vector)bnp1.dotTimes((Ring)betanp1);
        final Matrix outer = weightedBeta.outerProduct(alphan);
        return outer;
    }

    /**
     * Accumulates the per-step matrices from
     * {@link #computeTransitions(Vector, Vector, Vector)} over every pair of
     * consecutive steps, multiplies the sum element-wise by the current
     * transition matrix, and passes it to {@code normalizeTransitionMatrix}.
     *
     * @param alphas weighted forward vectors, one per step
     * @param betas weighted backward vectors, one per step
     * @param b per-state observation likelihoods, one vector per step; its
     *        size determines the number of steps
     * @return the resulting matrix
     */
    protected Matrix computeTransitions(ArrayList<WeightedValue<Vector>> alphas, ArrayList<WeightedValue<Vector>> betas, ArrayList<Vector> b) {
        final RingAccumulator counts = new RingAccumulator();
        final int lastStep = b.size() - 1;
        // "next" is the later step of each consecutive (next - 1, next) pair.
        for (int next = 1; next <= lastStep; ++next) {
            final Vector forward = (Vector)alphas.get(next - 1).getValue();
            final Vector backward = (Vector)betas.get(next).getValue();
            counts.accumulate((Ring)HiddenMarkovModel.computeTransitions(forward, backward, b.get(next)));
        }

        final Matrix estimate = (Matrix)counts.getSum();
        estimate.dotTimesEquals((Ring)this.getTransitionProbability());
        this.normalizeTransitionMatrix(estimate);
        return estimate;
    }

    /**
     * Returns the superclass description followed by {@code "F: "} and the
     * string form of each emission function.
     * <p>
     * Models created by the no-argument or single-int constructor have no
     * emission functions assigned; for those only the superclass description
     * is returned instead of failing with a {@link NullPointerException}.
     *
     * @return textual description of this model
     */
    @Override
    public String toString() {
        StringBuilder retval = new StringBuilder(super.toString());
        Collection<? extends ComputableDistribution<ObservationType>> functions = this.getEmissionFunctions();
        if (functions != null) {
            for (ComputableDistribution<ObservationType> f : functions) {
                // append(Object) also tolerates a null element.
                retval.append("F: ").append(f);
            }
        }
        return retval.toString();
    }

    /**
     * Finds the index {@code j} that maximizes the product of the
     * transition-matrix element at (row = {@code destinationState},
     * column = j) and {@code delta}'s element j.
     *
     * @param destinationState row of the transition matrix to use
     * @param delta vector whose elements weight the candidates
     * @return the maximizing index paired with the maximal product; ties
     *         keep the lowest index, and the index is -1 if no product
     *         exceeds negative infinity
     */
    protected WeightedValue<Integer> findMostLikelyState(int destinationState, Vector delta) {
        final int stateCount = delta.getDimensionality();
        int argMax = -1;
        double maxValue = Double.NEGATIVE_INFINITY;
        for (int source = 0; source < stateCount; ++source) {
            final double candidate = this.transitionProbability.getElement(destinationState, source) * delta.getElement(source);
            if (maxValue < candidate) {
                maxValue = candidate;
                argMax = source;
            }
        }
        return new DefaultWeightedValue((Object)argMax, maxValue);
    }

    /**
     * Performs one step of the Viterbi recursion.
     * <p>
     * For every destination state the best predecessor is looked up with
     * {@link #findMostLikelyState(int, Vector)}; the resulting values are
     * scaled element-wise by {@code bn} and then normalized to unit 1-norm.
     *
     * @param delta vector of the previous step (not modified)
     * @param bn per-state observation likelihoods of the current step
     * @return the new vector together with, for each state, the index of its
     *         best predecessor
     */
    protected Pair<Vector, int[]> computeViterbiRecursion(Vector delta, Vector bn) {
        final int stateCount = delta.getDimensionality();
        final int[] predecessors = new int[stateCount];
        final Vector next = VectorFactory.getDefault().createVector(stateCount);
        for (int destination = 0; destination < stateCount; ++destination) {
            final WeightedValue<Integer> best = this.findMostLikelyState(destination, delta);
            predecessors[destination] = (Integer)best.getValue();
            next.setElement(destination, best.getWeight());
        }
        next.dotTimesEquals((Ring)bn);
        next.scaleEquals(1.0 / next.norm1());
        return DefaultPair.create((Object)next, (Object)predecessors);
    }

    /**
     * Finds the most likely state sequence for the given observations using
     * the Viterbi algorithm.
     * <p>
     * The recursion starts from the element-wise product of the initial
     * probabilities and the likelihoods of the first observation, and is
     * advanced with {@link #computeViterbiRecursion(Vector, Vector)}, which
     * also yields the best-predecessor table of each step. The state with the
     * largest final value is then traced back through those tables.
     *
     * @param observations observation sequence, in order
     * @return one state index per observation, in order; an empty list when
     *         there are no observations
     */
    @PublicationReference(author={"Wikipedia"}, title="Viterbi algorithm", year=2010, type=PublicationType.WebPage, url="http://en.wikipedia.org/wiki/Viterbi_algorithm")
    public ArrayList<Integer> viterbi(Collection<? extends ObservationType> observations) {
        final int N = observations.size();
        if (N == 0) {
            // No observations means no states to decode; previously this
            // failed with an index error on the first likelihood vector.
            return new ArrayList<Integer>(0);
        }
        final int k = this.getNumStates();
        ArrayList<Vector> bs = this.computeObservationLikelihoods(observations);
        Vector delta = (Vector)this.getInitialProbability().dotTimes((Ring)bs.get(0));

        // Entry 0 is a placeholder (already zero-filled by Java) so that the
        // table for step n sits at index n; it is never read when tracing back.
        ArrayList<int[]> psis = new ArrayList<int[]>(N);
        psis.add(new int[k]);
        for (int n = 1; n < N; ++n) {
            Pair<Vector, int[]> pair = this.computeViterbiRecursion(delta, bs.get(n));
            delta = (Vector)pair.getFirst();
            psis.add((int[])pair.getSecond());
        }

        // Pick the best final state; ties keep the lowest index.
        int finalState = -1;
        double best = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < k; ++i) {
            double v = delta.getElement(i);
            if (best < v) {
                best = v;
                finalState = i;
            }
        }

        // Pre-size with nulls because the path is filled from the back.
        ArrayList<Integer> states = new ArrayList<Integer>(Collections.<Integer>nCopies(N, null));
        int state = finalState;
        states.set(N - 1, state);
        for (int n = N - 2; n >= 0; --n) {
            state = psis.get(n + 1)[state];
            states.set(n, state);
        }
        return states;
    }

    /**
     * Computes the normalized forward vector after each observation.
     * <p>
     * The observation likelihoods are run through
     * {@link #computeForwardProbabilities(ArrayList, boolean)} with
     * normalization enabled, and the weights are discarded.
     *
     * @param observations observation sequence, in order
     * @return one vector per observation, in order
     */
    public ArrayList<Vector> stateBeliefs(Collection<? extends ObservationType> observations) {
        final ArrayList<Vector> likelihoods = this.computeObservationLikelihoods(observations);
        final ArrayList<WeightedValue<Vector>> forward = this.computeForwardProbabilities(likelihoods, true);
        final int steps = forward.size();
        final ArrayList<Vector> beliefs = new ArrayList<Vector>(steps);
        for (int n = 0; n < steps; ++n) {
            beliefs.add((Vector)forward.get(n).getValue());
        }
        return beliefs;
    }
}

