/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.perceptron.kernel;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.DefaultKernelBinaryCategorizer;
import gov.sandia.cognition.learning.function.categorization.KernelBinaryCategorizer;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.NamedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;

@CodeReview(reviewer={"Kevin R. Dixon"}, date="2008-07-23", changesNeeded=false, comments={"Added PublicationReference to the original article.", "Minor changes to javadoc.", "Looks fine."})
@PublicationReference(author={"Yoav Freund", "Robert E. Schapire"}, title="Large margin classification using the perceptron algorithm", publication="Machine Learning", type=PublicationType.Journal, year=1999, notes={"Volume 37, Number 3"}, pages={277, 296}, url="http://www.cs.ucsd.edu/~yfreund/papers/LargeMarginsUsingPerceptron.pdf")
/**
 * Batch learner for a kernel perceptron, following the algorithm of Freund and
 * Schapire referenced above.
 * <p>
 * The learned model is a {@link DefaultKernelBinaryCategorizer} made of a set of
 * weighted support examples plus a bias. Each call to {@link #step()} makes one
 * pass over the training data. Every non-null example that is misclassified with
 * respect to the margins (see {@link #getMarginPositive()} and
 * {@link #getMarginNegative()}) adds +1 (positive example) or -1 (negative
 * example) to both its own support weight and the bias. A support whose weight
 * returns to exactly zero is dropped. Learning continues while a pass produces at
 * least one error, subject to the iteration limit given to the base class.
 *
 * @param <InputType> The type of input the kernel and learned categorizer accept.
 */
public class KernelPerceptron<InputType>
extends AbstractAnytimeSupervisedBatchLearner<InputType, Boolean, DefaultKernelBinaryCategorizer<InputType>>
implements MeasurablePerformanceAlgorithm {
    /** The default maximum number of passes over the data: {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /** The default margin required for positive examples: {@value}. */
    public static final double DEFAULT_MARGIN_POSITIVE = 0.0;

    /** The default margin required for negative examples: {@value}. */
    public static final double DEFAULT_MARGIN_NEGATIVE = 0.0;

    /** The kernel used by the learned categorizer. */
    private Kernel<? super InputType> kernel;

    /**
     * A positive example counts as an error when its prediction is less than or
     * equal to this value.
     */
    private double marginPositive;

    /**
     * A negative example counts as an error when its prediction is greater than
     * or equal to the negation of this value.
     */
    private double marginNegative;

    /** The categorizer being learned; null until learning has been initialized. */
    private DefaultKernelBinaryCategorizer<InputType> result;

    /** The number of errors made during the most recent pass over the data. */
    private int errorCount;

    /**
     * Maps each training example that currently has a nonzero weight to its
     * weighted support. It only exists while learning is in progress and is
     * cleared in {@link #cleanupAlgorithm()}.
     */
    private LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, DefaultWeightedValue<InputType>> supportsMap;

    /**
     * Creates a new kernel perceptron with no kernel, the default maximum number
     * of iterations and the default margins. A kernel should be set with
     * {@link #setKernel(Kernel)} before learning.
     */
    public KernelPerceptron() {
        this(null);
    }

    /**
     * Creates a new kernel perceptron with the given kernel, the default maximum
     * number of iterations and the default margins.
     *
     * @param kernel The kernel to use.
     */
    public KernelPerceptron(Kernel<? super InputType> kernel) {
        this(kernel, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a new kernel perceptron with the given kernel and iteration limit,
     * and the default margins.
     *
     * @param kernel The kernel to use.
     * @param maxIterations The maximum number of passes over the data.
     */
    public KernelPerceptron(Kernel<? super InputType> kernel, int maxIterations) {
        this(kernel, maxIterations, DEFAULT_MARGIN_POSITIVE, DEFAULT_MARGIN_NEGATIVE);
    }

    /**
     * Creates a new kernel perceptron with all parameters given explicitly.
     *
     * @param kernel The kernel to use.
     * @param maxIterations The maximum number of passes over the data.
     * @param marginPositive The margin required for positive examples.
     * @param marginNegative The margin required for negative examples.
     */
    public KernelPerceptron(Kernel<? super InputType> kernel, int maxIterations, double marginPositive, double marginNegative) {
        super(maxIterations);
        this.setKernel(kernel);
        this.setMarginPositive(marginPositive);
        this.setMarginNegative(marginNegative);
        this.setResult(null);
        this.setErrorCount(0);
        this.setSupportsMap(null);
    }

    /**
     * Prepares a fresh, empty categorizer (bias 0.0, no supports) for learning.
     *
     * @return False if the data is null or contains no non-null examples
     *     (nothing to learn); true otherwise.
     */
    @Override
    protected boolean initializeAlgorithm() {
        if (this.getData() == null) {
            return false;
        }

        int validCount = 0;
        for (InputOutputPair<? extends InputType, ? extends Boolean> example : this.getData()) {
            if (example != null) {
                ++validCount;
            }
        }

        if (validCount <= 0) {
            return false;
        }

        // Until the first pass runs, treat every valid example as an error.
        this.setErrorCount(validCount);
        this.setSupportsMap(new LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, DefaultWeightedValue<InputType>>());

        // The categorizer is built over the live values() view of the supports
        // map. step() only changes the map (and the bias), so this relies on the
        // categorizer reading through that view rather than copying it.
        this.setResult(new DefaultKernelBinaryCategorizer<InputType>(this.getKernel(), this.getSupportsMap().values(), 0.0));
        return true;
    }

    /**
     * Makes one perceptron pass over the training data, updating the support
     * weights and bias for every example that is misclassified within the
     * margins.
     *
     * @return True if the pass made at least one error, meaning learning should
     *     continue.
     */
    @Override
    protected boolean step() {
        this.setErrorCount(0);

        for (InputOutputPair<? extends InputType, ? extends Boolean> example : this.getData()) {
            if (example == null) {
                continue;
            }

            final InputType input = example.getInput();
            final boolean actual = example.getOutput();
            final double prediction = this.result.evaluateAsDouble(input);

            // An error is a prediction that is not strictly beyond the margin on
            // the correct side.
            final boolean isError = actual
                ? prediction <= this.marginPositive
                : prediction >= -this.marginNegative;
            if (!isError) {
                continue;
            }

            this.setErrorCount(this.getErrorCount() + 1);

            // Move both the example's weight and the bias toward the true label.
            final double delta = actual ? 1.0 : -1.0;
            final double bias = this.result.getBias() + delta;

            DefaultWeightedValue<InputType> support = this.supportsMap.get(example);
            final double weight = (support == null ? 0.0 : support.getWeight()) + delta;

            if (support == null) {
                // First error on this example: it becomes a new support.
                support = new DefaultWeightedValue<InputType>(input, weight);
                this.supportsMap.put(example, support);
            } else if (weight == 0.0) {
                // The contributions cancelled out, so the support is no longer needed.
                this.supportsMap.remove(example);
            } else {
                support.setWeight(weight);
            }

            this.result.setBias(bias);
        }

        return this.getErrorCount() > 0;
    }

    /**
     * Finishes learning. Replaces the categorizer's examples, which are a view of
     * the supports map, with a standalone {@link ArrayList} copy and then
     * discards the map.
     */
    @Override
    protected void cleanupAlgorithm() {
        if (this.getSupportsMap() != null) {
            this.getResult().setExamples(new ArrayList<DefaultWeightedValue<InputType>>(this.getSupportsMap().values()));
            this.setSupportsMap(null);
        }
    }

    /**
     * Gets the kernel used by the learned categorizer.
     *
     * @return The kernel; may be null.
     */
    public Kernel<? super InputType> getKernel() {
        return this.kernel;
    }

    /**
     * Sets the kernel used by the learned categorizer.
     *
     * @param kernel The kernel.
     */
    public void setKernel(Kernel<? super InputType> kernel) {
        this.kernel = kernel;
    }

    /**
     * Sets both the positive and the negative margin to the same value.
     *
     * @param margin The margin to use for both classes.
     */
    public void setMargin(double margin) {
        this.setMarginPositive(margin);
        this.setMarginNegative(margin);
    }

    /**
     * Gets the margin for positive examples. A positive example whose prediction
     * is less than or equal to this value is treated as an error.
     *
     * @return The positive margin.
     */
    public double getMarginPositive() {
        return this.marginPositive;
    }

    /**
     * Sets the margin for positive examples.
     *
     * @param marginPositive The positive margin.
     */
    public void setMarginPositive(double marginPositive) {
        this.marginPositive = marginPositive;
    }

    /**
     * Gets the margin for negative examples. A negative example whose prediction
     * is greater than or equal to the negation of this value is treated as an
     * error.
     *
     * @return The negative margin.
     */
    public double getMarginNegative() {
        return this.marginNegative;
    }

    /**
     * Sets the margin for negative examples.
     *
     * @param marginNegative The negative margin.
     */
    public void setMarginNegative(double marginNegative) {
        this.marginNegative = marginNegative;
    }

    /**
     * Gets the categorizer learned so far.
     *
     * @return The learned categorizer; null if learning has not been initialized.
     */
    public DefaultKernelBinaryCategorizer<InputType> getResult() {
        return this.result;
    }

    /**
     * Sets the categorizer being learned.
     *
     * @param result The categorizer.
     */
    protected void setResult(DefaultKernelBinaryCategorizer<InputType> result) {
        this.result = result;
    }

    /**
     * Gets the number of errors made during the most recent pass over the data.
     *
     * @return The error count.
     */
    public int getErrorCount() {
        return this.errorCount;
    }

    /**
     * Sets the error count.
     *
     * @param errorCount The error count.
     */
    protected void setErrorCount(int errorCount) {
        this.errorCount = errorCount;
    }

    /**
     * Gets the map from training examples to their weighted supports.
     *
     * @return The supports map; null outside of learning.
     */
    protected LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, DefaultWeightedValue<InputType>> getSupportsMap() {
        return this.supportsMap;
    }

    /**
     * Sets the map from training examples to their weighted supports.
     *
     * @param supportsMap The supports map.
     */
    protected void setSupportsMap(LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, DefaultWeightedValue<InputType>> supportsMap) {
        this.supportsMap = supportsMap;
    }

    /**
     * Reports the error count of the most recent pass as the performance value.
     *
     * @return A named value called "error count".
     */
    @Override
    public NamedValue<Integer> getPerformance() {
        return new DefaultNamedValue<Integer>("error count", this.getErrorCount());
    }
}

