package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.scalar.VectorFunctionLinearDiscriminant;
import gov.sandia.cognition.learning.function.vector.ScalarBasisSet;
import gov.sandia.cognition.math.UnivariateStatisticsUtil;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.distribution.ChiSquareDistribution;
import gov.sandia.cognition.statistics.method.AbstractConfidenceStatistic;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import net.sf.saxon.trace.Location;

/**
 * Batch learner that fits a {@link VectorFunctionLinearDiscriminant} to scalar
 * targets by linear least squares.
 * <p>
 * Every training input is first mapped to a {@link Vector} by the configured
 * input-to-vector map. Those vectors become the rows of a design matrix and the
 * outputs become the target vector. Both are multiplied by the per-example
 * weight reported by {@code DatasetUtil.getWeight}. The coefficient vector is
 * then computed from the pseudo-inverse of the design matrix (the default, using
 * {@link #DEFAULT_PSEUDO_INVERSE_TOLERANCE}) or via {@code Matrix.solve}.
 *
 * @param <InputType> type of the raw inputs consumed by the input-to-vector map
 */
@CodeReview(reviewer = {"Kevin R. Dixon"}, date = "2008-09-02", changesNeeded = false, comments = {"Made minor changes to javadoc", "Looks fine."})
// NOTE(review): "year = Location.TEMPLATE" looks like a decompiler artifact that substituted
// an unrelated Saxon (net.sf.saxon.trace) constant for an integer year literal. Confirm the
// intended year and drop the Saxon dependency.
@PublicationReference(author = {"Wikipedia"}, title = "Linear regression", type = PublicationType.WebPage, year = Location.TEMPLATE, url = "http://en.wikipedia.org/wiki/Linear_regression")
public class LinearRegression<InputType> extends AbstractCloneableSerializable implements SupervisedBatchLearner<InputType, Double, VectorFunctionLinearDiscriminant<InputType>> {

    /** Tolerance handed to {@code Matrix.pseudoInverse} when the pseudo-inverse solver is enabled. */
    public static final double DEFAULT_PSEUDO_INVERSE_TOLERANCE = 1.0E-10d;

    /** Result of the most recent call to {@link #learn}, or {@code null} before/while learning. */
    private VectorFunctionLinearDiscriminant<InputType> learned;

    /** Maps each raw input to the feature vector used as a row of the design matrix. */
    private Evaluator<? super InputType, Vector> inputToVectorMap;

    /** When true, solve via the pseudo-inverse; otherwise via {@code Matrix.solve}. */
    private boolean usePseudoInverse;

    /**
     * Goodness-of-fit summary comparing regression estimates against targets.
     * <p>
     * Residuals are formed as {@code weight * (target - estimate)}. Their sum of
     * squared differences from zero is reported as the chi-square test statistic.
     * That statistic is evaluated against a chi-square distribution with
     * {@code max(1, numSamples - numParameters)} degrees of freedom to obtain the
     * null-hypothesis probability.
     */
    public static class Statistic extends AbstractConfidenceStatistic {

        /** Sum of squared weighted residuals; also returned by {@link #getTestStatistic()}. */
        private double chiSquare;

        /** Root-mean-squared weighted residual. */
        private double rootMeanSquaredError;

        /** Sum of absolute weighted residuals divided by the sum of weights (0 if that sum is not positive). */
        private double meanL1Error;

        /** Correlation between the (unweighted) targets and estimates. */
        private double targetEstimateCorrelation;

        /** {@code 1 - correlation^2}. */
        private double unpredictedErrorFraction;

        /** Number of target/estimate pairs. */
        private int numSamples;

        /** Number of fitted parameters used to reduce the degrees of freedom. */
        private int numParameters;

        /** {@code max(1, numSamples - numParameters)}. */
        private double degreesOfFreedom;

        /**
         * Computes the statistic with every sample given weight 1.
         *
         * @param targets ground-truth values
         * @param estimates regression outputs, in the same iteration order as {@code targets}
         * @param numParameters number of fitted parameters
         * @throws IllegalArgumentException if the collections differ in size
         */
        public Statistic(Collection<Double> targets, Collection<Double> estimates, int numParameters) {
            super(0.0d);
            computeStatistics(targets, estimates, Collections.nCopies(targets.size(), 1.0d), numParameters);
        }

        /**
         * Computes the statistic using explicit per-sample weights.
         *
         * @param targets ground-truth values
         * @param estimates regression outputs, in the same iteration order as {@code targets}
         * @param weights per-sample weights, in the same iteration order as {@code targets}
         * @param numParameters number of fitted parameters
         * @throws IllegalArgumentException if the collections differ in size
         */
        public Statistic(Collection<Double> targets, Collection<Double> estimates, Collection<Double> weights, int numParameters) {
            super(0.0d);
            computeStatistics(targets, estimates, weights, numParameters);
        }

        /**
         * Copy constructor; copies every computed field, including the chi-square value.
         *
         * @param other statistic to copy
         */
        private Statistic(Statistic other) {
            super(other.getNullHypothesisProbability());
            setChiSquare(other.getChiSquare());
            setDegreesOfFreedom(other.getDegreesOfFreedom());
            setMeanL1Error(other.getMeanL1Error());
            setNumParameters(other.getNumParameters());
            setNumSamples(other.getNumSamples());
            setRootMeanSquaredError(other.getRootMeanSquaredError());
            setTargetEstimateCorrelation(other.getTargetEstimateCorrelation());
            setUnpredictedErrorFraction(other.getUnpredictedErrorFraction());
        }

        /**
         * Returns a copy produced by the base class's clone mechanism.
         * The method name appears to be a decompiler rename of {@code clone}; it is kept so
         * the override still matches the base class.
         */
        @Override // gov.sandia.cognition.util.AbstractCloneableSerializable
        public Statistic mo811clone() {
            return (Statistic) super.mo811clone();
        }

        /**
         * Walks targets, estimates and weights in lock-step and fills in every field.
         *
         * @throws IllegalArgumentException if any collection's size differs from {@code targets}
         */
        private void computeStatistics(Collection<Double> targets, Collection<Double> estimates, Collection<Double> weights, int numParameters) {
            final int size = targets.size();
            // All three iterators are advanced together, so every collection must match;
            // a single mismatch would otherwise surface as NoSuchElementException or be ignored.
            if (estimates.size() != size || weights.size() != size) {
                throw new IllegalArgumentException("Targets, Estimates, and Weights must be the same size!");
            }

            final ArrayList<Double> weightedErrors = new ArrayList<>(size);
            double sumAbsoluteError = 0.0d;
            double sumWeights = 0.0d;
            final Iterator<Double> targetIt = targets.iterator();
            final Iterator<Double> estimateIt = estimates.iterator();
            final Iterator<Double> weightIt = weights.iterator();
            for (int n = 0; n < size; n++) {
                final double estimate = estimateIt.next();
                final double target = targetIt.next();
                final double weight = weightIt.next();
                final double weightedError = weight * (target - estimate);
                weightedErrors.add(weightedError);
                sumAbsoluteError += Math.abs(weightedError);
                sumWeights += weight;
            }

            final double meanL1Error = (sumWeights > 0.0d) ? (sumAbsoluteError / sumWeights) : 0.0d;
            // size - numParameters may be zero or negative; clamp to at least one.
            final double degreesOfFreedom = Math.max(1.0d, size - numParameters);
            final double chiSquare = UnivariateStatisticsUtil.computeSumSquaredDifference(weightedErrors, 0.0d);
            final double nullHypothesisProbability = 1.0d - ChiSquareDistribution.CDF.evaluate(chiSquare, degreesOfFreedom);
            final double rootMeanSquaredError = UnivariateStatisticsUtil.computeRootMeanSquaredError(weightedErrors, 0.0d);
            final double correlation = UnivariateStatisticsUtil.computeCorrelation(targets, estimates);

            setNullHypothesisProbability(nullHypothesisProbability);
            setChiSquare(chiSquare);
            setDegreesOfFreedom(degreesOfFreedom);
            setMeanL1Error(meanL1Error);
            setNumSamples(size);
            setRootMeanSquaredError(rootMeanSquaredError);
            setTargetEstimateCorrelation(correlation);
            setUnpredictedErrorFraction(1.0d - (correlation * correlation));
            setNumParameters(numParameters);
        }

        /** @return root-mean-squared weighted residual */
        public double getRootMeanSquaredError() {
            return this.rootMeanSquaredError;
        }

        /** @param rootMeanSquaredError root-mean-squared weighted residual */
        protected void setRootMeanSquaredError(double rootMeanSquaredError) {
            this.rootMeanSquaredError = rootMeanSquaredError;
        }

        /** @return correlation between targets and estimates */
        public double getTargetEstimateCorrelation() {
            return this.targetEstimateCorrelation;
        }

        /** @param targetEstimateCorrelation correlation between targets and estimates */
        protected void setTargetEstimateCorrelation(double targetEstimateCorrelation) {
            this.targetEstimateCorrelation = targetEstimateCorrelation;
        }

        /** @return {@code 1 - correlation^2} */
        public double getUnpredictedErrorFraction() {
            return this.unpredictedErrorFraction;
        }

        /** @param unpredictedErrorFraction {@code 1 - correlation^2} */
        protected void setUnpredictedErrorFraction(double unpredictedErrorFraction) {
            this.unpredictedErrorFraction = unpredictedErrorFraction;
        }

        /** @return number of target/estimate pairs */
        public int getNumSamples() {
            return this.numSamples;
        }

        /** @param numSamples number of target/estimate pairs */
        protected void setNumSamples(int numSamples) {
            this.numSamples = numSamples;
        }

        /** @return degrees of freedom used for the chi-square test */
        public double getDegreesOfFreedom() {
            return this.degreesOfFreedom;
        }

        /** @param degreesOfFreedom degrees of freedom used for the chi-square test */
        protected void setDegreesOfFreedom(double degreesOfFreedom) {
            this.degreesOfFreedom = degreesOfFreedom;
        }

        /** @return weight-normalized mean absolute weighted residual */
        public double getMeanL1Error() {
            return this.meanL1Error;
        }

        /** @param meanL1Error weight-normalized mean absolute weighted residual */
        protected void setMeanL1Error(double meanL1Error) {
            this.meanL1Error = meanL1Error;
        }

        /** @return number of fitted parameters */
        public int getNumParameters() {
            return this.numParameters;
        }

        /** @param numParameters number of fitted parameters */
        public void setNumParameters(int numParameters) {
            this.numParameters = numParameters;
        }

        /** @return sum of squared weighted residuals */
        public double getChiSquare() {
            return this.chiSquare;
        }

        /** @param chiSquare sum of squared weighted residuals */
        public void setChiSquare(double chiSquare) {
            this.chiSquare = chiSquare;
        }

        /** @return the chi-square value, which serves as this statistic's test statistic */
        @Override // gov.sandia.cognition.statistics.method.ConfidenceStatistic
        public double getTestStatistic() {
            return getChiSquare();
        }
    }

    /**
     * Creates a regression whose features are the outputs of the given scalar basis functions.
     *
     * @param basisFunctions scalar functions, one per feature dimension
     */
    public LinearRegression(Evaluator<? super InputType, Double>... basisFunctions) {
        this(Arrays.asList(basisFunctions));
    }

    /**
     * Creates a regression whose features are the outputs of the given scalar basis functions.
     *
     * @param basisFunctions scalar functions, one per feature dimension
     */
    public LinearRegression(Collection<? extends Evaluator<? super InputType, Double>> basisFunctions) {
        this(new ScalarBasisSet(basisFunctions));
    }

    /**
     * Creates a regression that uses the given basis set as its input-to-vector map.
     *
     * @param basisSet basis set mapping inputs to feature vectors
     */
    public LinearRegression(ScalarBasisSet<InputType> basisSet) {
        // The cast selects the Evaluator constructor instead of recursing into this one.
        this((Evaluator) basisSet);
    }

    /**
     * Creates a regression with an explicit input-to-vector map. The pseudo-inverse solver is enabled by default.
     *
     * @param inputToVectorMap maps each input to a feature vector
     */
    public LinearRegression(Evaluator<? super InputType, Vector> inputToVectorMap) {
        setInputToVectorMap(inputToVectorMap);
        setUsePseudoInverse(true);
        setLearned(null);
    }

    /**
     * Returns a copy with the input-to-vector map and the learned result cloned through {@link ObjectUtil}.
     * The method name appears to be a decompiler rename of {@code clone}; it is kept so the
     * override still matches the base class.
     */
    @Override // gov.sandia.cognition.util.AbstractCloneableSerializable
    public LinearRegression<InputType> mo811clone() {
        LinearRegression<InputType> copy = (LinearRegression) super.mo811clone();
        copy.setInputToVectorMap((Evaluator) ObjectUtil.cloneSmart(getInputToVectorMap()));
        copy.setLearned((VectorFunctionLinearDiscriminant) ObjectUtil.cloneSafe(getLearned()));
        return copy;
    }

    /** @return the result of the most recent {@link #learn} call, or {@code null} */
    public VectorFunctionLinearDiscriminant<InputType> getLearned() {
        return this.learned;
    }

    /** @param learned the learned function to store */
    protected void setLearned(VectorFunctionLinearDiscriminant<InputType> learned) {
        this.learned = learned;
    }

    /**
     * Fits the linear coefficients to the given labelled data.
     * <p>
     * NOTE(review): rows and outputs are scaled by the raw weight {@code w}, so the solve
     * minimizes {@code sum(w^2 * residual^2)}. Textbook weighted least squares scales by
     * {@code sqrt(w)}. {@link Statistic} uses the same {@code w}-scaling, so confirm which
     * convention is intended before changing either.
     *
     * @param data labelled (and optionally weighted) training examples
     * @return the learned function, which is also stored as {@link #getLearned()}
     * @throws IllegalArgumentException if {@code data} is null or empty
     */
    @Override // gov.sandia.cognition.learning.algorithm.BatchLearner
    public VectorFunctionLinearDiscriminant<InputType> learn(Collection<? extends InputOutputPair<? extends InputType, Double>> data) {
        setLearned(null);
        if (data == null || data.isEmpty()) {
            throw new IllegalArgumentException("Cannot learn a linear regression from null or empty data");
        }

        // The first example fixes the feature dimensionality (number of design-matrix columns).
        final int numColumns = this.inputToVectorMap.evaluate(data.iterator().next().getInput()).getDimensionality();
        final int numRows = data.size();
        final Matrix design = MatrixFactory.getDefault().createMatrix(numRows, numColumns);
        final Vector targets = VectorFactory.getDefault().createVector(numRows);

        int row = 0;
        for (InputOutputPair<? extends InputType, Double> pair : data) {
            final double weight = DatasetUtil.getWeight(pair);
            targets.setElement(row, pair.getOutput() * weight);
            design.setRow(row, this.inputToVectorMap.evaluate(pair.getInput()).scale(weight));
            row++;
        }

        final Vector coefficients = getUsePseudoInverse()
            ? design.pseudoInverse(DEFAULT_PSEUDO_INVERSE_TOLERANCE).times(targets)
            : design.solve(targets);
        setLearned(new VectorFunctionLinearDiscriminant<>(this.inputToVectorMap, coefficients));
        return getLearned();
    }

    /** @return the map from raw inputs to feature vectors */
    public Evaluator<? super InputType, Vector> getInputToVectorMap() {
        return this.inputToVectorMap;
    }

    /** @param inputToVectorMap the map from raw inputs to feature vectors */
    public void setInputToVectorMap(Evaluator<? super InputType, Vector> inputToVectorMap) {
        this.inputToVectorMap = inputToVectorMap;
    }

    /** @return true if {@link #learn} solves via the pseudo-inverse, false if via {@code Matrix.solve} */
    public boolean getUsePseudoInverse() {
        return this.usePseudoInverse;
    }

    /** @param usePseudoInverse true to solve via the pseudo-inverse, false to use {@code Matrix.solve} */
    public void setUsePseudoInverse(boolean usePseudoInverse) {
        this.usePseudoInverse = usePseudoInverse;
    }
}
