/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.pca;

import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner;
import gov.sandia.cognition.learning.algorithm.pca.PrincipalComponentsAnalysis;
import gov.sandia.cognition.learning.algorithm.pca.PrincipalComponentsAnalysisFunction;
import gov.sandia.cognition.learning.function.vector.MatrixMultiplyVectorFunction;
import gov.sandia.cognition.math.MultivariateStatisticsUtil;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Iterative principal components analysis based on the Generalized Hebbian
 * Algorithm (Sanger's rule, see the publication reference on this class).
 * <p>
 * The learner keeps one estimate vector per requested component. Each call to
 * {@link #step()} sweeps once over the mean-centered data and nudges every
 * estimate toward the sample, after removing the part of the sample already
 * explained by the current and all lower-indexed components. Iteration stops
 * when the accumulated change of the estimates (divided by the learning rate)
 * falls to {@link #getMinChange()} or below, or becomes non-finite.
 */
@CodeReview(reviewer={"Kevin R. Dixon"}, date="2008-07-23", changesNeeded=false, comments={"Added PublicationReference to Sanger's master's thesis.", "Minor changes to javadoc.", "Looks fine."})
@PublicationReference(author={"Terrence D. Sanger"}, title="Optimal Unsupervised Learning in a Single-Layer Linear Feedforward Neural Network", type=PublicationType.Thesis, year=1989, url="http://ece-classweb.ucsd.edu/winter06/ece173/documents/Sanger%201989%20--%20Optimal%20Unsupervised%20Learning%20in%20a%20Single-layer%20Linear%20FeedforwardNN.pdf")
public class GeneralizedHebbianAlgorithm
extends AbstractAnytimeBatchLearner<Collection<Vector>, PrincipalComponentsAnalysisFunction>
implements PrincipalComponentsAnalysis {
    /** Diagnostic logger; progress messages are emitted at {@code FINE}. */
    private static final Logger LOGGER = Logger.getLogger(GeneralizedHebbianAlgorithm.class.getName());

    /** Step size applied to every component update, in (0,1]. */
    private double learningRate;

    /** Number of principal components to estimate, always greater than zero. */
    private int numComponents;

    /** Function produced by the most recent run, or null before the first run. */
    private PrincipalComponentsAnalysisFunction result;

    /** Current (unnormalized) component estimates, one vector per component. */
    private ArrayList<Vector> components;

    /** Mean of the training data, subtracted from every sample before learning. */
    private Vector mean;

    /** Convergence threshold on the per-iteration change, never negative. */
    private double minChange;

    /**
     * Creates a new learner.
     *
     * @param numComponents number of components to extract, must be &gt; 0
     * @param learningRate step size of the update rule, must be in (0,1]
     * @param maxIterations iteration limit handed to the superclass
     * @param minChange stop once the per-iteration change is at or below this
     *        value, must be &gt;= 0
     * @throws IllegalArgumentException if any argument validated by the
     *         setters of this class is out of range
     */
    public GeneralizedHebbianAlgorithm(int numComponents, double learningRate, int maxIterations, double minChange) {
        super(maxIterations);
        this.setNumComponents(numComponents);
        this.setLearningRate(learningRate);
        this.setMinChange(minChange);
        this.setResult(null);
    }

    /**
     * Creates a copy whose data, result, mean and component estimates are
     * independent of this instance.
     *
     * @return the copy
     */
    @Override
    public GeneralizedHebbianAlgorithm clone() {
        GeneralizedHebbianAlgorithm clone = (GeneralizedHebbianAlgorithm)super.clone();
        clone.setData(ObjectUtil.cloneSmartElementsAsArrayList(this.getData()));
        clone.setResult(ObjectUtil.cloneSafe(this.getResult()));
        clone.mean = ObjectUtil.cloneSafe(this.mean);
        // step() updates the component vectors in place, so the copy must own
        // its own vectors or stepping one learner would corrupt the other.
        clone.components = GeneralizedHebbianAlgorithm.copyVectors(this.components);
        return clone;
    }

    /**
     * Prepares a run: replaces the data with a private copy, centers that copy
     * on its mean and seeds component {@code i} with the {@code i}-th standard
     * basis vector.
     *
     * @return true when the learner is ready to iterate; false when there is
     *         no data to learn from (null or empty collection)
     * @throws IllegalArgumentException if more components are requested than
     *         the data has dimensions
     */
    @Override
    protected boolean initializeAlgorithm() {
        Collection<Vector> original = this.getData();
        if (original == null || original.isEmpty()) {
            // Nothing to learn from; report failure instead of failing later
            // with a NullPointerException or NoSuchElementException.
            return false;
        }

        // Work on copies because the samples are mean-centered in place below.
        ArrayList<Vector> centered = ObjectUtil.cloneSmartElementsAsArrayList(original);
        this.setData(centered);

        int M = this.getNumComponents();
        int N = centered.get(0).getDimensionality();
        if (M > N) {
            throw new IllegalArgumentException("Number of EigenVectors must be <= dimension of Vectors");
        }

        this.mean = MultivariateStatisticsUtil.computeMean(centered);
        for (Vector x : centered) {
            x.minusEquals(this.mean);
        }

        this.components = new ArrayList<Vector>(M);
        for (int i = 0; i < M; ++i) {
            Vector ui = VectorFactory.getDefault().createVector(N);
            ui.setElement(i, 1.0);
            this.components.add(ui);
        }
        return true;
    }

    /**
     * Finishes a run by packing the unit-length component estimates into the
     * rows of a matrix and storing the resulting analysis function as the
     * result. Does nothing when no components or mean are available.
     */
    @Override
    protected void cleanupAlgorithm() {
        if (LOGGER.isLoggable(Level.FINE)) {
            LOGGER.fine("Stopping after " + this.getIteration());
        }
        if (this.components == null || this.components.isEmpty() || this.mean == null) {
            return;
        }

        int M = this.getNumComponents();
        // All components were created with the same dimensionality.
        int N = this.components.get(0).getDimensionality();
        Matrix Umatrix = MatrixFactory.getDefault().createMatrix(M, N);
        for (int i = 0; i < M; ++i) {
            Umatrix.setRow(i, this.components.get(i).unitVector());
        }
        this.setResult(new PrincipalComponentsAnalysisFunction(this.mean, new MatrixMultiplyVectorFunction(Umatrix)));
    }

    /**
     * Performs one sweep over the centered data, applying to each sample
     * {@code x} and each component {@code u_i} the update
     * {@code u_i += alpha * y_i * (x - sum_{j<=i} y_j * u_j)} with
     * {@code y_j = u_j . x}.
     *
     * @return false when the summed Euclidean change of all components,
     *         divided by the learning rate, is at or below the minimum change
     *         or is NaN/infinite; true otherwise
     */
    @Override
    protected boolean step() {
        int M = this.getNumComponents();
        ArrayList<Vector> before = GeneralizedHebbianAlgorithm.copyVectors(this.components);
        double alpha = this.getLearningRate();

        for (Vector x : this.getData()) {
            // Running sum of y_j * u_j for j = 0..i. Exactly one term is added
            // per component, so each lower component is subtracted once, and
            // the term for u_i is computed before u_i itself is modified.
            RingAccumulator<Vector> explained = new RingAccumulator<Vector>();
            for (int i = 0; i < M; ++i) {
                Vector ui = this.components.get(i);
                double yi = ui.dotProduct(x);
                explained.accumulate(ui.scale(yi));
                Vector delta = x.minus(explained.getSum()).scale(yi * alpha);
                ui.plusEquals(delta);
            }
        }

        double change = 0.0;
        for (int i = 0; i < M; ++i) {
            change += this.components.get(i).minus(before.get(i)).norm2();
        }
        // Normalize by the learning rate so the threshold does not depend on it.
        change /= alpha;

        if (LOGGER.isLoggable(Level.FINE)) {
            LOGGER.fine(this.getIteration() + ": Change = " + change + ", Alpha = " + alpha);
        }

        boolean converged = change <= this.getMinChange();
        boolean diverged = Double.isNaN(change) || Double.isInfinite(change);
        return !(converged || diverged);
    }

    /**
     * Copies a list of vectors element by element.
     *
     * @param vectors list to copy, may be null
     * @return a new list holding clones of the given vectors, or null when the
     *         argument is null
     */
    private static ArrayList<Vector> copyVectors(ArrayList<Vector> vectors) {
        if (vectors == null) {
            return null;
        }
        ArrayList<Vector> copy = new ArrayList<Vector>(vectors.size());
        for (Vector v : vectors) {
            copy.add(v.clone());
        }
        return copy;
    }

    /**
     * Gets the learning rate.
     *
     * @return step size of the update rule
     */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the learning rate.
     *
     * @param learningRate step size of the update rule, must be in (0,1]
     * @throws IllegalArgumentException if the value is outside (0,1]
     */
    public void setLearningRate(double learningRate) {
        if (learningRate <= 0.0 || learningRate > 1.0) {
            throw new IllegalArgumentException("LearningRate must be (0,1]");
        }
        this.learningRate = learningRate;
    }

    /**
     * Gets the convergence threshold.
     *
     * @return change at or below which iteration stops
     */
    public double getMinChange() {
        return this.minChange;
    }

    /**
     * Sets the convergence threshold.
     *
     * @param minChange change at or below which iteration stops, must be
     *        &gt;= 0
     * @throws IllegalArgumentException if the value is negative
     */
    public void setMinChange(double minChange) {
        if (minChange < 0.0) {
            throw new IllegalArgumentException("minChange must be greater than or equal to zero");
        }
        this.minChange = minChange;
    }

    /**
     * Gets the number of components to estimate.
     *
     * @return number of components
     */
    @Override
    public int getNumComponents() {
        return this.numComponents;
    }

    /**
     * Sets the number of components to estimate.
     *
     * @param numComponents number of components, must be &gt; 0
     * @throws IllegalArgumentException if the value is not positive
     */
    public void setNumComponents(int numComponents) {
        if (numComponents <= 0) {
            throw new IllegalArgumentException("Number of components must be > 0");
        }
        this.numComponents = numComponents;
    }

    /**
     * Gets the analysis function stored by the last cleanup.
     *
     * @return the stored result, or null if none has been produced
     */
    @Override
    public PrincipalComponentsAnalysisFunction getResult() {
        return this.result;
    }

    /**
     * Stores the analysis function.
     *
     * @param result function to store, may be null
     */
    protected void setResult(PrincipalComponentsAnalysisFunction result) {
        this.result = result;
    }
}

