/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.algorithm.regression.LinearRegression;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.scalar.VectorFunctionLinearDiscriminant;
import gov.sandia.cognition.learning.function.vector.DecoupledVectorFunction;
import gov.sandia.cognition.learning.function.vector.ScalarBasisSet;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.UnivariateScalarFunction;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;

/**
 * Learns a {@link DecoupledVectorFunction} by running an independent
 * {@link LinearRegression} for each element function.
 * <p>
 * The vector dataset is split with {@link DatasetUtil#decoupleVectorPairDataset}
 * into one scalar-to-scalar dataset per dimension. The i-th of those datasets is
 * fit with the i-th {@link ScalarBasisSet} in {@link #getElementFunctions()}.
 * (Presumably dataset i pairs input element i with output element i; see
 * {@code DatasetUtil} to confirm.)
 */
public class DecoupledVectorLinearRegression
extends AbstractCloneableSerializable
implements SupervisedBatchLearner<Vector, Vector, DecoupledVectorFunction> {
    /** One basis set per dimension; element i is used to regress dimension i. */
    private Collection<ScalarBasisSet<Double>> elementFunctions;
    /** Result of the most recent call to {@link #learn}; null until then. */
    private DecoupledVectorFunction learned;
    /** Parameter count handed to the per-dimension statistics. */
    private int numParameters;

    /**
     * Creates a learner that uses the same basis functions for every dimension.
     *
     * @param dimensionality number of dimensions (one basis-set clone per dimension)
     * @param basisFunctions basis functions shared (by cloning) across dimensions
     */
    public DecoupledVectorLinearRegression(int dimensionality, UnivariateScalarFunction ... basisFunctions) {
        this(dimensionality, Arrays.asList(basisFunctions));
    }

    /**
     * Creates a learner that uses the same basis functions for every dimension.
     *
     * @param dimensionality number of dimensions (one basis-set clone per dimension)
     * @param basisFunctions basis functions wrapped into a {@link ScalarBasisSet}
     */
    public DecoupledVectorLinearRegression(int dimensionality, Collection<? extends Evaluator<? super Double, Double>> basisFunctions) {
        this(dimensionality, new ScalarBasisSet<Double>(basisFunctions));
    }

    /**
     * Creates a learner whose element functions are {@code dimensionality}
     * independent clones of {@code elementFunction}.
     *
     * @param dimensionality number of dimensions; must be positive, since the
     *        element-function collection may not be empty
     * @param elementFunction basis set to clone for each dimension
     * @throws IllegalArgumentException if {@code dimensionality <= 0}
     */
    public DecoupledVectorLinearRegression(int dimensionality, ScalarBasisSet<Double> elementFunction) {
        ArrayList<ScalarBasisSet<Double>> functions = new ArrayList<ScalarBasisSet<Double>>(dimensionality);
        for (int i = 0; i < dimensionality; ++i) {
            // Clone so each dimension owns an independent basis set.
            functions.add((ScalarBasisSet<Double>) elementFunction.clone());
        }
        this.setElementFunctions(functions);
        this.setLearned(null);
        this.setNumParameters(elementFunction.getOutputDimensionality());
    }

    /**
     * Creates a learner from an explicit per-dimension collection of basis sets.
     * The collection is stored as given (not copied).
     *
     * @param elementFunctions basis sets, one per dimension; must be non-empty
     * @throws IllegalArgumentException if {@code elementFunctions} is empty
     */
    public DecoupledVectorLinearRegression(Collection<ScalarBasisSet<Double>> elementFunctions) {
        this.setElementFunctions(elementFunctions);
        this.setLearned(null);
        // Safe: setElementFunctions has already rejected an empty collection.
        this.setNumParameters(elementFunctions.iterator().next().getOutputDimensionality());
    }

    /**
     * Clones this learner, deep-copying the element functions. The learned
     * function reference is copied as-is by {@code super.clone()}.
     *
     * @return a copy of this learner
     */
    @Override
    public DecoupledVectorLinearRegression clone() {
        DecoupledVectorLinearRegression clone = (DecoupledVectorLinearRegression)super.clone();
        clone.setElementFunctions(ObjectUtil.cloneSmartElementsAsArrayList(this.getElementFunctions()));
        return clone;
    }

    /**
     * Gets the per-dimension basis sets.
     *
     * @return the element functions (the live collection, not a copy)
     */
    public Collection<ScalarBasisSet<Double>> getElementFunctions() {
        return this.elementFunctions;
    }

    /**
     * Sets the per-dimension basis sets.
     *
     * @param elementFunctions basis sets, one per dimension; must be non-empty
     * @throws IllegalArgumentException if {@code elementFunctions} is empty
     */
    public void setElementFunctions(Collection<ScalarBasisSet<Double>> elementFunctions) {
        if (elementFunctions.isEmpty()) {
            throw new IllegalArgumentException("Must have at least one function in the function Collection");
        }
        this.elementFunctions = elementFunctions;
    }

    /**
     * Gets the number of dimensions this learner handles.
     *
     * @return the number of element functions
     */
    public int getDimensionality() {
        return this.getElementFunctions().size();
    }

    /**
     * Fits one {@link LinearRegression} per dimension and combines the results
     * into a {@link DecoupledVectorFunction}.
     * <p>
     * Side effects: stores the result (see {@link #getLearned()}) and updates
     * {@link #getNumParameters()} to the parameter-vector dimensionality of the
     * first per-dimension regression.
     *
     * @param data vector input/output pairs to learn from
     * @return the learned decoupled function
     * @throws IllegalArgumentException if the decoupled data dimensionality
     *         differs from {@link #getDimensionality()}
     */
    @Override
    public DecoupledVectorFunction learn(Collection<? extends InputOutputPair<? extends Vector, Vector>> data) {
        ArrayList<ArrayList<InputOutputPair<Double, Double>>> decoupledDataset = DatasetUtil.decoupleVectorPairDataset(data);
        final int dimensionality = this.getDimensionality();
        if (decoupledDataset.size() != dimensionality) {
            // Fail with a clear message instead of an opaque index error, or
            // silently ignoring extra dimensions.
            throw new IllegalArgumentException("Data dimensionality (" + decoupledDataset.size()
                + ") must equal the number of element functions (" + dimensionality + ")");
        }
        ArrayList<VectorFunctionLinearDiscriminant<Double>> regressionFunctions = new ArrayList<VectorFunctionLinearDiscriminant<Double>>(dimensionality);
        int i = 0;
        int params = -1;
        for (ScalarBasisSet<Double> fi : this.getElementFunctions()) {
            ArrayList<InputOutputPair<Double, Double>> rowDataset = decoupledDataset.get(i);
            LinearRegression<Double> lri = new LinearRegression<Double>(fi);
            regressionFunctions.add(lri.learn(rowDataset));
            if (params < 0) {
                // Parameter count is taken from the first dimension's fit.
                params = lri.getLearned().convertToVector().getDimensionality();
            }
            ++i;
        }
        this.setLearned(new DecoupledVectorFunction(regressionFunctions));
        this.setNumParameters(params);
        return this.getLearned();
    }

    /**
     * Evaluates the learned function on {@code data} and summarizes the errors.
     *
     * @param data input/output pairs to evaluate; per-pair weights come from
     *        {@link DatasetUtil#getWeight}
     * @return per-dimension and joint error statistics
     * @throws IllegalStateException if {@link #learn} has not produced a function yet
     */
    public Statistic computeStatistics(Collection<? extends InputOutputPair<Vector, Vector>> data) {
        final DecoupledVectorFunction function = this.getLearned();
        if (function == null) {
            throw new IllegalStateException("learn() must be called before computeStatistics()");
        }
        ArrayList<Vector> targets = new ArrayList<Vector>(data.size());
        ArrayList<Vector> estimates = new ArrayList<Vector>(data.size());
        ArrayList<Double> weights = new ArrayList<Double>(data.size());
        for (InputOutputPair<Vector, Vector> inputOutputPair : data) {
            targets.add(inputOutputPair.getOutput());
            estimates.add(function.evaluate(inputOutputPair.getInput()));
            weights.add(DatasetUtil.getWeight(inputOutputPair));
        }
        return new Statistic(targets, estimates, weights, this.getNumParameters());
    }

    /**
     * Gets the function produced by the most recent {@link #learn} call.
     *
     * @return the learned function, or null if {@link #learn} has not run
     */
    public DecoupledVectorFunction getLearned() {
        return this.learned;
    }

    /**
     * Sets the learned function.
     *
     * @param learned the learned function (may be null)
     */
    public void setLearned(DecoupledVectorFunction learned) {
        this.learned = learned;
    }

    /**
     * Gets the parameter count passed to {@link Statistic}.
     * <p>
     * It is initialized from the basis-set output dimensionality in the
     * constructors and overwritten by {@link #learn}.
     *
     * @return the number of parameters
     */
    public int getNumParameters() {
        return this.numParameters;
    }

    /**
     * Sets the parameter count passed to {@link Statistic}.
     *
     * @param numParameters the number of parameters
     */
    public void setNumParameters(int numParameters) {
        this.numParameters = numParameters;
    }

    /**
     * Error statistics for a decoupled vector regression.
     * <p>
     * Holds one {@link LinearRegression.Statistic} per dimension, plus a
     * {@link MultivariateGaussian} fit to the joint (target - estimate) error
     * vectors.
     */
    public static class Statistic
    extends AbstractCloneableSerializable {
        /**
         * Second argument passed to
         * {@code MultivariateGaussian.MaximumLikelihoodEstimator.learn} when
         * fitting the joint error distribution. Presumably a default/regularizing
         * covariance; confirm against that estimator.
         */
        public static final double SMALL_COVARIANCE = 1.0E-10;
        /** Per-dimension statistics, in dimension order. */
        private Collection<LinearRegression.Statistic> elementStatistics;
        /** Gaussian fit to the (target - estimate) error vectors. */
        private MultivariateGaussian jointErrorStatistics;

        /**
         * Computes per-dimension and joint error statistics.
         *
         * @param targets ground-truth vectors
         * @param estimates estimated vectors, in the same order as {@code targets}
         * @param weights per-sample weights, in the same order as {@code targets}
         * @param numParameters parameter count forwarded to each
         *        {@link LinearRegression.Statistic}
         * @throws IllegalArgumentException if the three collections differ in
         *         size, or if targets and estimates differ in dimensionality
         */
        public Statistic(Collection<Vector> targets, Collection<Vector> estimates, Collection<Double> weights, int numParameters) {
            final int num = targets.size();
            if (num != estimates.size()) {
                throw new IllegalArgumentException("Number of targets must equal the number of estimates");
            }
            if (num != weights.size()) {
                throw new IllegalArgumentException("Number of targets must equal the number of weights");
            }
            ArrayList<ArrayList<Double>> decoupledTargets = DatasetUtil.decoupleVectorDataset(targets);
            ArrayList<ArrayList<Double>> decoupledEstimates = DatasetUtil.decoupleVectorDataset(estimates);
            if (decoupledTargets.size() != decoupledEstimates.size()) {
                throw new IllegalArgumentException("Dimensionality of targets must equal the dimensionality of estimates");
            }
            final int dimensionality = decoupledTargets.size();
            ArrayList<LinearRegression.Statistic> statistics = new ArrayList<LinearRegression.Statistic>(dimensionality);
            for (int i = 0; i < dimensionality; ++i) {
                statistics.add(new LinearRegression.Statistic(decoupledTargets.get(i), decoupledEstimates.get(i), weights, numParameters));
            }
            // Error vectors (target - estimate), walked in lock-step; the sizes
            // are equal per the checks above.
            ArrayList<Vector> errors = new ArrayList<Vector>(num);
            Iterator<Vector> ti = targets.iterator();
            Iterator<Vector> ei = estimates.iterator();
            for (int n = 0; n < num; ++n) {
                errors.add(ti.next().minus(ei.next()));
            }
            MultivariateGaussian.PDF errorCovarianceStatistics = MultivariateGaussian.MaximumLikelihoodEstimator.learn(errors, SMALL_COVARIANCE);
            this.setJointErrorStatistics(errorCovarianceStatistics);
            this.setElementStatistics(statistics);
        }

        /**
         * Deep-clones the joint error distribution and the per-dimension statistics.
         *
         * @return a copy of these statistics
         */
        @Override
        public Statistic clone() {
            Statistic clone = (Statistic)super.clone();
            clone.setJointErrorStatistics((MultivariateGaussian)ObjectUtil.cloneSafe((CloneableSerializable)this.getJointErrorStatistics()));
            clone.setElementStatistics(ObjectUtil.cloneSmartElementsAsArrayList(this.getElementStatistics()));
            return clone;
        }

        /**
         * Gets the per-dimension statistics.
         *
         * @return per-dimension statistics, in dimension order
         */
        public Collection<LinearRegression.Statistic> getElementStatistics() {
            return this.elementStatistics;
        }

        /**
         * Sets the per-dimension statistics.
         *
         * @param elementStatistics per-dimension statistics
         */
        public void setElementStatistics(Collection<LinearRegression.Statistic> elementStatistics) {
            this.elementStatistics = elementStatistics;
        }

        /**
         * Gets the Gaussian fit to the joint error vectors.
         *
         * @return the joint error distribution
         */
        public MultivariateGaussian getJointErrorStatistics() {
            return this.jointErrorStatistics;
        }

        /**
         * Sets the Gaussian fit to the joint error vectors.
         *
         * @param jointErrorStatistics the joint error distribution
         */
        public void setJointErrorStatistics(MultivariateGaussian jointErrorStatistics) {
            this.jointErrorStatistics = jointErrorStatistics;
        }
    }
}

