package gov.sandia.cognition.learning.algorithm.perceptron.kernel;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.DefaultKernelBinaryCategorizer;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.NamedValue;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;

/**
 * Batch learner for a kernel-based binary perceptron.
 *
 * <p>Each call to {@link #step()} makes one pass over the training data. For
 * every example the current categorizer is evaluated; when the example is
 * positive and the prediction is at or below the positive margin, or the
 * example is negative and the prediction is at or above the negated negative
 * margin, the example counts as an error and both its weight in the support
 * map and the categorizer bias are moved by one in the direction of the true
 * label. Learning continues while a pass produces at least one error (subject
 * to the iteration limit handed to the superclass).
 *
 * @param <InputType> the type of input the kernel and resulting categorizer
 *        operate on
 */
@CodeReview(reviewer = {"Kevin R. Dixon"}, date = "2008-07-23", changesNeeded = false, comments = {"Added PublicationReference to the original article.", "Minor changes to javadoc.", "Looks fine."})
@PublicationReference(author = {"Yoav Freund", "Robert E. Schapire"}, title = "Large margin classification using the perceptron algorithm", publication = "Machine Learning", type = PublicationType.Journal, year = 1999, notes = {"Volume 37, Number 3"}, pages = {277, 296}, url = "http://www.cs.ucsd.edu/~yfreund/papers/LargeMarginsUsingPerceptron.pdf")
/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/perceptron/kernel/KernelPerceptron.class */
public class KernelPerceptron<InputType> extends AbstractAnytimeSupervisedBatchLearner<InputType, Boolean, DefaultKernelBinaryCategorizer<InputType>> implements MeasurablePerformanceAlgorithm {
    /** Default maximum number of iterations: {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /** Default margin applied to positive examples: {@value}. */
    public static final double DEFAULT_MARGIN_POSITIVE = 0.0d;

    /** Default margin applied to negative examples: {@value}. */
    public static final double DEFAULT_MARGIN_NEGATIVE = 0.0d;

    /** Kernel handed to the categorizer that is being learned. */
    private Kernel<? super InputType> kernel;

    /** Positive examples with a prediction at or below this value are errors. */
    private double marginPositive;

    /** Negative examples with a prediction at or above minus this value are errors. */
    private double marginNegative;

    /** The categorizer being learned; {@code null} until learning is initialized. */
    private DefaultKernelBinaryCategorizer<InputType> result;

    /** Number of errors counted during the most recent pass over the data. */
    private int errorCount;

    /**
     * Maps each training example that currently has a non-zero weight to its
     * weighted support entry. Only populated while learning is in progress;
     * reset to {@code null} by {@link #cleanupAlgorithm()}.
     */
    private LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, DefaultWeightedValue<InputType>> supportsMap;

    /**
     * Creates a perceptron with no kernel and all default parameters.
     */
    public KernelPerceptron() {
        this(null);
    }

    /**
     * Creates a perceptron with the given kernel and default parameters.
     *
     * @param kernel the kernel to use
     */
    public KernelPerceptron(Kernel<? super InputType> kernel) {
        this(kernel, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a perceptron with the given kernel and iteration limit, using
     * the default margins.
     *
     * @param kernel the kernel to use
     * @param maxIterations the maximum number of iterations, passed to the
     *        superclass constructor
     */
    public KernelPerceptron(Kernel<? super InputType> kernel, int maxIterations) {
        this(kernel, maxIterations, DEFAULT_MARGIN_POSITIVE, DEFAULT_MARGIN_NEGATIVE);
    }

    /**
     * Creates a perceptron with the given kernel, iteration limit and margins.
     *
     * @param kernel the kernel to use
     * @param maxIterations the maximum number of iterations, passed to the
     *        superclass constructor
     * @param marginPositive the margin to enforce for positive examples
     * @param marginNegative the margin to enforce for negative examples
     */
    public KernelPerceptron(Kernel<? super InputType> kernel, int maxIterations, double marginPositive, double marginNegative) {
        super(maxIterations);
        setKernel(kernel);
        setMarginPositive(marginPositive);
        setMarginNegative(marginNegative);
        setResult(null);
        setErrorCount(0);
        setSupportsMap(null);
    }

    /**
     * Prepares the learning state: counts the non-null training examples,
     * creates an empty support map and a categorizer backed by that map's
     * values with a bias of zero.
     *
     * @return {@code false} when there is no data or every example is
     *         {@code null}; {@code true} otherwise
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean initializeAlgorithm() {
        if (getData() == null) {
            return false;
        }

        // Count the usable (non-null) examples.
        int validCount = 0;
        for (InputOutputPair<? extends InputType, ? extends Boolean> example : getData()) {
            if (example != null) {
                validCount++;
            }
        }
        if (validCount <= 0) {
            return false;
        }

        // The error count starts at the number of valid examples.
        setErrorCount(validCount);

        final LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, DefaultWeightedValue<InputType>> supports =
                new LinkedHashMap<>();
        setSupportsMap(supports);

        // The categorizer receives the map's values() collection directly.
        setResult(new DefaultKernelBinaryCategorizer<InputType>(getKernel(), getSupportsMap().values(), 0.0d));
        return true;
    }

    /**
     * Performs one pass over the training data, updating the weight of each
     * misclassified (or insufficiently separated) example and the bias.
     *
     * @return {@code true} if at least one error was found during the pass,
     *         meaning another pass is warranted; {@code false} otherwise
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean step() {
        setErrorCount(0);

        for (InputOutputPair<? extends InputType, ? extends Boolean> example : getData()) {
            if (example == null) {
                continue;
            }

            final InputType input = example.getInput();
            final boolean actual = example.getOutput().booleanValue();
            final double prediction = this.result.evaluateAsDouble(input);

            // An example is an error when its prediction falls on the wrong
            // side of (or exactly on) the margin for its label.
            final boolean isError = actual
                    ? prediction <= this.marginPositive
                    : prediction >= -this.marginNegative;
            if (!isError) {
                continue;
            }

            setErrorCount(getErrorCount() + 1);

            // Move the weight and the bias by one toward the true label.
            final double update = actual ? 1.0d : -1.0d;
            final double newBias = this.result.getBias() + update;
            final DefaultWeightedValue<InputType> support = this.supportsMap.get(example);
            final double oldWeight = support != null ? support.getWeight() : 0.0d;
            final double newWeight = oldWeight + update;

            if (support == null) {
                // First error on this example: it becomes a support.
                this.supportsMap.put(example, new DefaultWeightedValue<InputType>(input, newWeight));
            } else if (newWeight == 0.0d) {
                // Weights move in whole steps of one, so an exact zero means
                // the example no longer contributes and can be dropped.
                this.supportsMap.remove(example);
            } else {
                support.setWeight(newWeight);
            }

            this.result.setBias(newBias);
        }

        return getErrorCount() > 0;
    }

    /**
     * Finishes learning by giving the result its own copy of the current
     * supports and releasing the support map.
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected void cleanupAlgorithm() {
        if (getSupportsMap() != null) {
            // Copy the values so the result no longer depends on the map.
            getResult().setExamples(new ArrayList<DefaultWeightedValue<InputType>>(getSupportsMap().values()));
            setSupportsMap(null);
        }
    }

    /**
     * Gets the kernel used by the learner.
     *
     * @return the kernel, possibly {@code null} if none has been set
     */
    public Kernel<? super InputType> getKernel() {
        return this.kernel;
    }

    /**
     * Sets the kernel used by the learner.
     *
     * @param kernel the kernel to use
     */
    public void setKernel(Kernel<? super InputType> kernel) {
        this.kernel = kernel;
    }

    /**
     * Sets both the positive and the negative margin to the same value.
     *
     * @param margin the margin to use for both classes
     */
    public void setMargin(double margin) {
        setMarginPositive(margin);
        setMarginNegative(margin);
    }

    /**
     * Gets the margin enforced for positive examples.
     *
     * @return the positive margin
     */
    public double getMarginPositive() {
        return this.marginPositive;
    }

    /**
     * Sets the margin enforced for positive examples.
     *
     * @param marginPositive the positive margin
     */
    public void setMarginPositive(double marginPositive) {
        this.marginPositive = marginPositive;
    }

    /**
     * Gets the margin enforced for negative examples.
     *
     * @return the negative margin
     */
    public double getMarginNegative() {
        return this.marginNegative;
    }

    /**
     * Sets the margin enforced for negative examples.
     *
     * @param marginNegative the negative margin
     */
    public void setMarginNegative(double marginNegative) {
        this.marginNegative = marginNegative;
    }

    /**
     * Gets the categorizer learned so far.
     *
     * @return the current result, or {@code null} if learning has not been
     *         initialized
     */
    @Override // gov.sandia.cognition.algorithm.AnytimeAlgorithm
    public DefaultKernelBinaryCategorizer<InputType> getResult() {
        return this.result;
    }

    /**
     * Sets the categorizer being learned.
     *
     * @param result the categorizer to store
     */
    protected void setResult(DefaultKernelBinaryCategorizer<InputType> result) {
        this.result = result;
    }

    /**
     * Gets the number of errors counted during the most recent pass.
     *
     * @return the error count
     */
    public int getErrorCount() {
        return this.errorCount;
    }

    /**
     * Sets the error count.
     *
     * @param errorCount the error count to store
     */
    protected void setErrorCount(int errorCount) {
        this.errorCount = errorCount;
    }

    /**
     * Gets the map from training examples to their weighted support entries.
     *
     * @return the support map, or {@code null} when learning is not in
     *         progress
     */
    protected LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, DefaultWeightedValue<InputType>> getSupportsMap() {
        return this.supportsMap;
    }

    /**
     * Sets the map from training examples to their weighted support entries.
     *
     * @param supportsMap the support map to store
     */
    protected void setSupportsMap(LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, DefaultWeightedValue<InputType>> supportsMap) {
        this.supportsMap = supportsMap;
    }

    /**
     * Reports the current error count as the performance measure.
     *
     * @return a named value called {@code "error count"} holding the current
     *         error count
     */
    @Override // gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm
    public NamedValue<Integer> getPerformance() {
        return new DefaultNamedValue<Integer>("error count", Integer.valueOf(getErrorCount()));
    }
}
