package gov.sandia.cognition.learning.algorithm.confidence;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractSupervisedBatchAndIncrementalLearner;
import gov.sandia.cognition.learning.function.categorization.DiagonalConfidenceWeightedBinaryCategorizer;
import gov.sandia.cognition.math.matrix.DiagonalMatrix;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.ArgumentChecker;
import net.sf.saxon.trace.Location;

/**
 * Online learner for a {@link DiagonalConfidenceWeightedBinaryCategorizer}
 * using the standard-deviation ("exact convex") confidence-weighted update.
 * <p>
 * The learned model is a Gaussian over the weight vector, represented by a
 * mean vector and a per-dimension variance vector (a diagonal covariance).
 * An example triggers an update only when its signed margin is not more than
 * {@code phi} margin standard deviations above zero, where {@code phi} is the
 * standard-normal quantile of the configured confidence level.
 */
@PublicationReference(
    author = {"Koby Crammer", "Mark Dredze", "Fernando Pereira"},
    title = "Exact Convex Confidence-Weighted Learning",
    year = 2008,
    type = PublicationType.Conference,
    publication = "Advances in Neural Information Processing Systems",
    url = "http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.169.3364")
public class ConfidenceWeightedDiagonalDeviation extends AbstractSupervisedBatchAndIncrementalLearner<Vectorizable, Boolean, DiagonalConfidenceWeightedBinaryCategorizer> {

    /** Default confidence level, {@value}. */
    public static final double DEFAULT_CONFIDENCE = 0.85d;

    /** Default initial variance assigned to every dimension, {@value}. */
    public static final double DEFAULT_DEFAULT_VARIANCE = 1.0d;

    /** Confidence level; validated to lie in [0.5, 1.0]. */
    protected double confidence;

    /** Initial variance for each dimension of a newly initialized model. */
    protected double defaultVariance;

    /** Standard-normal quantile of {@code confidence}: -InverseCDF(1 - confidence). */
    protected double phi;

    /** Cached value 1 + phi^2 / 2 (called psi in the paper). */
    protected double psi;

    /** Cached value 1 + phi^2 (called zeta in the paper). */
    protected double epsilon;

    /**
     * Creates a learner with {@link #DEFAULT_CONFIDENCE} and
     * {@link #DEFAULT_DEFAULT_VARIANCE}.
     */
    public ConfidenceWeightedDiagonalDeviation() {
        this(DEFAULT_CONFIDENCE, DEFAULT_DEFAULT_VARIANCE);
    }

    /**
     * Creates a learner with the given parameters.
     *
     * @param confidence the confidence level, in [0.5, 1.0]
     * @param defaultVariance the positive initial variance for each dimension
     */
    public ConfidenceWeightedDiagonalDeviation(double confidence, double defaultVariance) {
        setConfidence(confidence);
        setDefaultVariance(defaultVariance);
    }

    /**
     * Creates a new, uninitialized categorizer; its mean and variance are
     * allocated lazily on the first call to {@code update}.
     *
     * @return a new categorizer
     */
    @Override // gov.sandia.cognition.learning.algorithm.IncrementalLearner
    public DiagonalConfidenceWeightedBinaryCategorizer createInitialLearnedObject() {
        return new DiagonalConfidenceWeightedBinaryCategorizer();
    }

    /**
     * Updates the categorizer with one labeled example. Examples with a
     * {@code null} input or label are silently ignored.
     *
     * @param target the categorizer to update in place
     * @param input the example input
     * @param label the example label
     */
    @Override // gov.sandia.cognition.learning.algorithm.SupervisedIncrementalLearner
    public void update(DiagonalConfidenceWeightedBinaryCategorizer target, Vectorizable input, Boolean label) {
        if (input == null || label == null) {
            return;
        }
        update(target, input.convertToVector(), label.booleanValue());
    }

    /**
     * Updates the categorizer with one labeled example.
     * <p>
     * If the categorizer is not yet initialized, a zero mean and a variance
     * vector filled with {@link #getDefaultVariance()} are created with the
     * input's dimensionality. The mean and variance vectors are then updated
     * in place and stored back on the categorizer.
     *
     * @param target the categorizer to update in place
     * @param input the example input vector
     * @param label the example label ({@code true} maps to +1, {@code false} to -1)
     */
    public void update(DiagonalConfidenceWeightedBinaryCategorizer target, Vector input, boolean label) {
        final Vector mean;
        final Vector variance;
        if (target.isInitialized()) {
            mean = target.getMean();
            variance = target.getVariance();
        } else {
            final int dimensionality = input.getDimensionality();
            mean = VectorFactory.getDenseDefault().createVector(dimensionality);
            variance = VectorFactory.getDenseDefault().createVector(dimensionality, getDefaultVariance());
            target.setMean(mean);
            target.setVariance(variance);
        }

        final double actual = label ? 1.0d : -1.0d;
        // Signed margin m = y * (x . mu).
        final double margin = actual * input.dotProduct(mean);
        // Sigma * x (elementwise because Sigma is diagonal); reused by the mean update.
        final Vector varianceTimesInput = input.dotTimes(variance);
        // Margin variance v = x^T Sigma x.
        final double marginVariance = input.dotProduct(varianceTimesInput);
        if (marginVariance == 0.0d || margin > this.phi * Math.sqrt(marginVariance)) {
            // Either the example carries no information or it is already
            // classified correctly with the required confidence.
            return;
        }

        final double phiSquared = this.phi * this.phi;
        final double alpha = Math.max(
            (-margin * this.psi
                + Math.sqrt(margin * margin * phiSquared * phiSquared / 4.0d
                    + marginVariance * phiSquared * this.epsilon))
                / (marginVariance * this.epsilon),
            0.0d);
        // Written as !(alpha > 0) so that a NaN alpha (e.g. from an infinite
        // phi) skips the update instead of writing NaN into the model.
        if (!(alpha > 0.0d)) {
            return;
        }

        // sqrt(u) where u = x^T Sigma_{new} x. Algebraically equal to
        // (-a + sqrt(a^2 + 4v)) / 2 with a = alpha * v * phi, but written in
        // this form to avoid cancellation when a^2 dominates 4v.
        final double alphaVariancePhi = alpha * marginVariance * this.phi;
        final double sqrtU = 2.0d * marginVariance
            / (alphaVariancePhi + Math.sqrt(alphaVariancePhi * alphaVariancePhi + 4.0d * marginVariance));

        // Amount added to the precision: Sigma^-1 += alpha * phi / sqrt(u) * diag(x^2).
        // (beta = alpha*phi / (sqrt(u) + v*alpha*phi) is the coefficient of the
        // equivalent covariance-form update Sigma -= beta * Sigma x x^T Sigma,
        // not of the precision-form update.)
        final double precisionIncrement = alpha * this.phi / sqrtU;

        // mu += alpha * y * Sigma x
        mean.plusEquals(varianceTimesInput.scale(actual * alpha));

        // Diagonal precision update done per dimension; dimensions where the
        // input is zero are left unchanged.
        final int dimensionality = variance.getDimensionality();
        for (int i = 0; i < dimensionality; i++) {
            final double x = input.getElement(i);
            if (x != 0.0d) {
                final double oldPrecision = 1.0d / variance.getElement(i);
                variance.setElement(i, 1.0d / (oldPrecision + precisionIncrement * x * x));
            }
        }

        target.setMean(mean);
        target.setVariance(variance);
    }

    /**
     * Gets the confidence level.
     *
     * @return the confidence level, in [0.5, 1.0]
     */
    public double getConfidence() {
        return this.confidence;
    }

    /**
     * Sets the confidence level and recomputes the derived constants
     * {@code phi}, {@code psi} and {@code epsilon}.
     * <p>
     * NOTE(review): a confidence of exactly 1.0 passes validation, but the
     * inverse CDF at 0 is presumably infinite, which would make {@code phi}
     * infinite. With such a {@code phi} the update step detects the NaN
     * step size and performs no update. Confirm whether 1.0 should be rejected.
     *
     * @param confidence the confidence level, in [0.5, 1.0]
     */
    public void setConfidence(double confidence) {
        ArgumentChecker.assertIsInRangeInclusive("confidence", confidence, 0.5d, 1.0d);
        this.confidence = confidence;
        this.phi = -UnivariateGaussian.CDF.Inverse.evaluate(1.0d - confidence, 0.0d, 1.0d);
        final double phiSquared = this.phi * this.phi;
        this.psi = 1.0d + phiSquared / 2.0d;
        this.epsilon = 1.0d + phiSquared;
    }

    /**
     * Gets the initial per-dimension variance used when a model is first
     * initialized.
     *
     * @return the default variance
     */
    public double getDefaultVariance() {
        return this.defaultVariance;
    }

    /**
     * Sets the initial per-dimension variance used when a model is first
     * initialized.
     *
     * @param defaultVariance the default variance; must be positive
     */
    public void setDefaultVariance(double defaultVariance) {
        ArgumentChecker.assertIsPositive("defaultVariance", defaultVariance);
        this.defaultVariance = defaultVariance;
    }
}
