/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.function.vector;

import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.function.vector.DifferentiableSquashedMatrixMultiplyVectorFunction;
import gov.sandia.cognition.learning.function.vector.ElementWiseDifferentiableVectorFunction;
import gov.sandia.cognition.learning.function.vector.FeedforwardNeuralNetwork;
import gov.sandia.cognition.learning.function.vector.MatrixMultiplyVectorFunction;
import gov.sandia.cognition.math.DifferentiableUnivariateScalarFunction;
import gov.sandia.cognition.math.matrix.DifferentiableVectorFunction;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

/**
 * A {@link FeedforwardNeuralNetwork} whose layers are all
 * {@link DifferentiableSquashedMatrixMultiplyVectorFunction}s, which lets the
 * network compute the derivative of its output with respect to its
 * parameters (see {@link #computeParameterGradient(Vector)}).
 */
public class DifferentiableFeedforwardNeuralNetwork
extends FeedforwardNeuralNetwork
implements GradientDescendable {

    /**
     * Half-width of the interval passed to
     * {@code MatrixFactory.createUniformRandom} when initializing weights in
     * the layer-list constructor.
     */
    private static final double LAYER_LIST_INIT_RANGE = 0.1;

    /**
     * Half-width of the interval passed to
     * {@code MatrixFactory.createUniformRandom} when initializing weights in
     * the input/hidden/output constructors.
     */
    private static final double THREE_LAYER_INIT_RANGE = 1.0;

    /**
     * Creates a network with one layer per consecutive pair of entries in
     * {@code nodesPerLayer}. Layer {@code i} has a
     * {@code nodesPerLayer.get(i + 1)} x {@code nodesPerLayer.get(i)} weight
     * matrix, randomly initialized with bounds
     * {@code [-0.1, 0.1]}, and is squashed by
     * {@code layerActivationFunctions.get(i)}.
     *
     * @param nodesPerLayer number of nodes in each layer, starting with the input layer
     * @param layerActivationFunctions activation for each weight layer; entry {@code i}
     *        is used for the layer mapping layer {@code i} to layer {@code i + 1}
     * @param random source of randomness for the weight initialization
     */
    public DifferentiableFeedforwardNeuralNetwork(ArrayList<Integer> nodesPerLayer, ArrayList<DifferentiableUnivariateScalarFunction> layerActivationFunctions, Random random) {
        super(new ArrayList<DifferentiableSquashedMatrixMultiplyVectorFunction>());
        ArrayList<DifferentiableSquashedMatrixMultiplyVectorFunction> layers =
            new ArrayList<DifferentiableSquashedMatrixMultiplyVectorFunction>(layerActivationFunctions.size());
        for (int i = 0; i < nodesPerLayer.size() - 1; ++i) {
            int currentNum = nodesPerLayer.get(i);
            int nextNum = nodesPerLayer.get(i + 1);
            Matrix w = MatrixFactory.getDefault().createUniformRandom(
                nextNum, currentNum, -LAYER_LIST_INIT_RANGE, LAYER_LIST_INIT_RANGE, random);
            layers.add(new DifferentiableSquashedMatrixMultiplyVectorFunction(
                new MatrixMultiplyVectorFunction(w), layerActivationFunctions.get(i)));
        }
        this.setLayers(layers);
    }

    /**
     * Creates a two-weight-layer network (input, hidden, output). Both weight
     * matrices are randomly initialized with bounds {@code [-1.0, 1.0]}. The
     * same activation function is used for both layers.
     *
     * @param numInputs dimensionality of the input
     * @param numHiddens number of hidden nodes
     * @param numOutputs dimensionality of the output
     * @param activationFunction squashing function applied after each matrix multiply
     * @param random source of randomness for the weight initialization
     */
    public DifferentiableFeedforwardNeuralNetwork(int numInputs, int numHiddens, int numOutputs, DifferentiableVectorFunction activationFunction, Random random) {
        super(new ArrayList<DifferentiableSquashedMatrixMultiplyVectorFunction>());
        ArrayList<DifferentiableSquashedMatrixMultiplyVectorFunction> layers =
            new ArrayList<DifferentiableSquashedMatrixMultiplyVectorFunction>(2);
        Matrix w12 = MatrixFactory.getDefault().createUniformRandom(
            numHiddens, numInputs, -THREE_LAYER_INIT_RANGE, THREE_LAYER_INIT_RANGE, random);
        Matrix w23 = MatrixFactory.getDefault().createUniformRandom(
            numOutputs, numHiddens, -THREE_LAYER_INIT_RANGE, THREE_LAYER_INIT_RANGE, random);
        layers.add(new DifferentiableSquashedMatrixMultiplyVectorFunction(new MatrixMultiplyVectorFunction(w12), activationFunction));
        layers.add(new DifferentiableSquashedMatrixMultiplyVectorFunction(new MatrixMultiplyVectorFunction(w23), activationFunction));
        this.setLayers(layers);
    }

    /**
     * Creates a two-weight-layer network whose activation is the given scalar
     * function applied element-wise.
     *
     * @param numInputs dimensionality of the input
     * @param numHiddens number of hidden nodes
     * @param numOutputs dimensionality of the output
     * @param scalarFunction scalar function wrapped in an
     *        {@link ElementWiseDifferentiableVectorFunction}
     * @param random source of randomness for the weight initialization
     */
    public DifferentiableFeedforwardNeuralNetwork(int numInputs, int numHiddens, int numOutputs, DifferentiableUnivariateScalarFunction scalarFunction, Random random) {
        this(numInputs, numHiddens, numOutputs, new ElementWiseDifferentiableVectorFunction(scalarFunction), random);
    }

    /**
     * Creates a network from the given layers, in evaluation order.
     *
     * @param layers the layers; they are copied into a new list
     */
    public DifferentiableFeedforwardNeuralNetwork(DifferentiableSquashedMatrixMultiplyVectorFunction ... layers) {
        super(new ArrayList<DifferentiableSquashedMatrixMultiplyVectorFunction>(Arrays.asList(layers)));
    }

    @Override
    public DifferentiableFeedforwardNeuralNetwork clone() {
        return (DifferentiableFeedforwardNeuralNetwork)super.clone();
    }

    /**
     * Returns the layers of this network, in evaluation order.
     *
     * @return the list of layers held by the superclass
     */
    @Override
    public ArrayList<DifferentiableSquashedMatrixMultiplyVectorFunction> getLayers() {
        return super.getLayers();
    }

    /**
     * Computes the derivative of the network output with respect to every
     * layer's parameters, using the chain rule from the last layer back to the
     * first.
     *
     * <p>The result has one row per output dimension. Its columns are the
     * concatenation of per-layer blocks, ordered from the first layer to the
     * last. The block for layer {@code i} is the product of the input
     * derivatives of every layer after {@code i} with layer {@code i}'s own
     * parameter gradient, each evaluated at that layer's input.
     *
     * @param input the network input at which to evaluate the gradient
     * @return an (output dimensionality) x (total parameter count) matrix
     */
    @Override
    public Matrix computeParameterGradient(Vector input) {
        ArrayList<DifferentiableSquashedMatrixMultiplyVectorFunction> layers = this.getLayers();
        int numLayers = layers.size();
        // Entry i is used as the input to layer i; entry numLayers is the output.
        ArrayList<Vector> layerActivations = this.evaluateAtEachLayer(input);
        int numOutputs = layerActivations.get(numLayers).getDimensionality();

        Matrix[] layerGradients = new Matrix[numLayers];
        int numParameters = 0;
        // d(output)/d(output of layer i), accumulated backwards from the output.
        Matrix layerDerivative = MatrixFactory.getDefault().createIdentity(numOutputs, numOutputs);
        for (int i = numLayers - 1; i >= 0; --i) {
            DifferentiableSquashedMatrixMultiplyVectorFunction layer = layers.get(i);
            Vector layerInput = layerActivations.get(i);
            Matrix layerGradient = layerDerivative.times(layer.computeParameterGradient(layerInput));
            layerGradients[i] = layerGradient;
            numParameters += layerGradient.getNumColumns();
            if (i > 0) {
                // Chain through this layer's input derivative for the layers before it;
                // not needed once the first layer has been processed.
                layerDerivative = layerDerivative.times(layer.differentiate(layerInput));
            }
        }

        Matrix gradient = MatrixFactory.getDefault().createMatrix(numOutputs, numParameters);
        int columnIndex = 0;
        for (Matrix layerGradient : layerGradients) {
            // Copy every block in full: even the last layer's block need not be
            // sparse when its activation is a general (non-element-wise)
            // DifferentiableVectorFunction, so copying only a subset of its
            // entries could drop gradient terms.
            gradient.setSubMatrix(0, columnIndex, layerGradient);
            columnIndex += layerGradient.getNumColumns();
        }
        return gradient;
    }
}

