package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.IncrementalLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.AbstractSufficientStatistic;
import gov.sandia.cognition.statistics.Distribution;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;
import net.sf.saxon.trace.Location;

@PublicationReferences(references = {
    @PublicationReference(author = {"Christopher M. Bishop"}, title = "Pattern Recognition and Machine Learning", type = PublicationType.Book, year = 2006, pages = {152, 159}),
    @PublicationReference(author = {"Hanna M. Wallach"}, title = "Introduction to Gaussian Process Regression", type = PublicationType.Misc, year = 2005, url = "http://www.cs.umass.edu/~wallach/talks/gp_intro.pdf"),
    // The decompiler substituted the unrelated Saxon constant Location.BUILT_IN_TEMPLATE for an
    // integer literal here. NOTE(review): 2010 is presumed to be the original year -- confirm
    // against the value of that constant / the upstream source.
    @PublicationReference(author = {"Wikipedia"}, title = "Bayesian linear regression", type = PublicationType.WebPage, year = 2010, url = "http://en.wikipedia.org/wiki/Bayesian_linear_regression")
})
/**
 * Bayesian linear regression with a Gaussian prior over the weight vector and
 * Gaussian output noise of fixed variance {@link #getOutputVariance()}.
 * <p>
 * Each input is mapped to a feature vector {@code phi} by {@code featureMap}. In
 * {@link #learn(Collection)}, the posterior precision is the prior precision
 * plus {@code sum_i (w_i / outputVariance) * phi_i * phi_i^T}, where {@code w_i}
 * is the weight of the i-th pair. The posterior mean is the posterior covariance
 * times {@code priorPrecision * priorMean + sum_i (w_i / outputVariance) * y_i * phi_i}.
 *
 * @param <InputType> type of the inputs given to the feature map
 */
public class BayesianLinearRegression<InputType> extends AbstractBayesianRegression<InputType, Double, MultivariateGaussian> {

    /** Default variance of the Gaussian output noise, {@value}. */
    public static final double DEFAULT_OUTPUT_VARIANCE = 1.0d;

    /**
     * Default per-weight variance of the isotropic, zero-mean prior built by
     * {@link #BayesianLinearRegression(int)}, {@value}.
     */
    public static final double DEFAULT_WEIGHT_VARIANCE = 1.0d;

    /** Variance of the Gaussian output noise; {@link #setOutputVariance(double)} keeps it > 0. */
    protected double outputVariance;

    /** Gaussian prior over the weight vector. */
    protected MultivariateGaussian weightPrior;

    /**
     * Incremental version of {@link BayesianLinearRegression}. It accumulates the
     * posterior precision matrix and the precision-weighted target vector one
     * observation at a time in a {@link SufficientStatistic}.
     *
     * @param <InputType> type of the inputs given to the feature map
     */
    public static class IncrementalEstimator<InputType> extends BayesianLinearRegression<InputType> implements IncrementalLearner<InputOutputPair<? extends InputType, Double>, IncrementalEstimator<InputType>.SufficientStatistic> {

        /**
         * Running statistics of a Bayesian linear regression posterior: the
         * precision (inverse covariance) matrix and the vector {@code z} whose
         * solution against that precision is the posterior mean.
         */
        public class SufficientStatistic extends AbstractSufficientStatistic<InputOutputPair<? extends InputType, Double>, MultivariateGaussian> {

            /** Running sum of precision-weighted targets; {@code null} until something is accumulated. */
            private Vector z;

            /** Running posterior precision matrix; {@code null} until something is accumulated. */
            private Matrix covarianceInverse;

            /**
             * Creates the statistic, seeded from the given prior when one is supplied.
             *
             * @param prior prior over the weights, or {@code null} to start empty.
             *     A non-null prior counts as one observation.
             */
            public SufficientStatistic(MultivariateGaussian prior) {
                if (prior == null) {
                    this.covarianceInverse = null;
                    this.z = null;
                    this.count = 0L;
                } else {
                    // Clone so that later plusEquals calls never mutate the prior itself.
                    this.covarianceInverse = prior.getCovarianceInverse().mo784clone();
                    this.z = this.covarianceInverse.times(prior.getMean());
                    this.count = 1L;
                }
            }

            /**
             * Adds one observed input-output pair to the statistic.
             *
             * @param pair observation; its weight (via {@link DatasetUtil#getWeight})
             *     is divided by the output variance
             */
            @Override
            public void update(InputOutputPair<? extends InputType, Double> pair) {
                this.count++;
                final Vector phi = IncrementalEstimator.this.featureMap.evaluate(pair.getInput());
                // Scaled copy of phi; phi itself stays unscaled for the outer product.
                final Vector scaledPhi = phi.mo784clone();
                final double y = pair.getOutput().doubleValue();
                final double beta = DatasetUtil.getWeight(pair) / IncrementalEstimator.this.outputVariance;
                if (beta != 1.0d) {
                    scaledPhi.scaleEquals(beta);
                }
                if (this.covarianceInverse == null) {
                    this.covarianceInverse = phi.outerProduct(scaledPhi);
                } else {
                    this.covarianceInverse.plusEquals(phi.outerProduct(scaledPhi));
                }
                if (y != 1.0d) {
                    scaledPhi.scaleEquals(y);
                }
                if (this.z == null) {
                    // scaledPhi is a fresh clone, so it is safe to keep directly.
                    this.z = scaledPhi;
                } else {
                    this.z.plusEquals(scaledPhi);
                }
            }

            /**
             * Creates a new Gaussian PDF holding the current posterior.
             *
             * @return posterior over the weights
             */
            @Override
            public MultivariateGaussian.PDF create() {
                final MultivariateGaussian.PDF pdf = new MultivariateGaussian.PDF(getDimensionality());
                create(pdf);
                return pdf;
            }

            /**
             * Writes the current posterior mean and precision into the given distribution.
             *
             * @param distribution distribution to overwrite
             */
            @Override
            public void create(MultivariateGaussian distribution) {
                distribution.setMean(getMean());
                distribution.setCovarianceInverse(getCovarianceInverse());
            }

            /**
             * Gets the accumulated precision matrix.
             *
             * @return precision matrix; this is the live internal object, not a copy
             */
            public Matrix getCovarianceInverse() {
                return this.covarianceInverse;
            }

            /**
             * Gets the accumulated precision-weighted target vector.
             *
             * @return the {@code z} vector; this is the live internal object, not a copy
             */
            public Vector getZ() {
                return this.z;
            }

            /**
             * Computes the posterior mean as {@code covarianceInverse^-1 * z}.
             *
             * @return posterior mean of the weights
             */
            public Vector getMean() {
                return this.covarianceInverse.inverse().times(this.z);
            }

            /**
             * Gets the dimensionality of the weight space.
             *
             * @return dimensionality of {@code z}
             */
            public int getDimensionality() {
                return getZ().getDimensionality();
            }
        }

        /**
         * Creates an estimator with the default isotropic prior and no feature map.
         *
         * @param dimensionality number of weights
         */
        public IncrementalEstimator(int dimensionality) {
            super(dimensionality);
        }

        /**
         * Creates an estimator with the default isotropic prior.
         *
         * @param dimensionality number of weights
         * @param featureMap maps inputs to feature vectors of length {@code dimensionality}
         */
        public IncrementalEstimator(int dimensionality, Evaluator<? super InputType, Vector> featureMap) {
            this(dimensionality);
            setFeatureMap(featureMap);
        }

        /**
         * Creates an estimator with explicit parameters.
         *
         * @param featureMap maps inputs to feature vectors
         * @param outputVariance output noise variance; must be > 0
         * @param weightPrior prior over the weights
         */
        public IncrementalEstimator(Evaluator<? super InputType, Vector> featureMap, double outputVariance, MultivariateGaussian weightPrior) {
            super(featureMap, outputVariance, weightPrior);
        }

        /**
         * Creates a statistic seeded from the current weight prior.
         *
         * @return new sufficient statistic
         */
        @Override
        public IncrementalEstimator<InputType>.SufficientStatistic createInitialLearnedObject() {
            return new SufficientStatistic(getWeightPrior());
        }

        /**
         * Computes the weight posterior by feeding every pair through a fresh
         * sufficient statistic.
         *
         * @param data observed input-output pairs
         * @return posterior over the weights
         */
        @Override
        public MultivariateGaussian.PDF learn(Collection<? extends InputOutputPair<? extends InputType, Double>> data) {
            final IncrementalEstimator<InputType>.SufficientStatistic statistic = createInitialLearnedObject();
            update(statistic, data);
            return statistic.create();
        }

        /**
         * Adds one observation to the given statistic.
         *
         * @param target statistic to update
         * @param data observation to add
         */
        @Override
        public void update(IncrementalEstimator<InputType>.SufficientStatistic target, InputOutputPair<? extends InputType, Double> data) {
            target.update(data);
        }

        /**
         * Adds every observation to the given statistic.
         *
         * @param target statistic to update
         * @param data observations to add
         */
        @Override
        public void update(IncrementalEstimator<InputType>.SufficientStatistic target, Iterable<? extends InputOutputPair<? extends InputType, Double>> data) {
            target.update(data);
        }
    }

    /**
     * Predictive distribution of the output for a new input under a given weight posterior.
     * For features {@code phi}, the output is Gaussian with mean {@code phi . mean} and variance
     * {@code phi^T * covariance * phi + outputVariance}.
     */
    @PublicationReference(author = {"Christopher M. Bishop"}, title = "Pattern Recognition and Machine Learning", type = PublicationType.Book, year = 2006, pages = {156})
    public class PredictiveDistribution extends AbstractCloneableSerializable implements Evaluator<InputType, UnivariateGaussian.PDF> {

        /** Posterior over the weights used for prediction. */
        private MultivariateGaussian posterior;

        /**
         * Creates the predictive distribution.
         *
         * @param posterior posterior over the weights
         */
        public PredictiveDistribution(MultivariateGaussian posterior) {
            this.posterior = posterior;
        }

        /**
         * Evaluates the predictive output distribution at the given input.
         *
         * @param input input to predict at
         * @return univariate Gaussian over the output
         */
        @Override
        public UnivariateGaussian.PDF evaluate(InputType input) {
            final Vector phi = BayesianLinearRegression.this.featureMap.evaluate(input);
            final double mean = phi.dotProduct(this.posterior.getMean());
            final double variance = phi.times(this.posterior.getCovariance()).dotProduct(phi) + BayesianLinearRegression.this.outputVariance;
            return new UnivariateGaussian.PDF(mean, variance);
        }
    }

    /**
     * Creates a regression with a zero-mean prior of covariance
     * {@code DEFAULT_WEIGHT_VARIANCE * I}, output variance
     * {@code DEFAULT_OUTPUT_VARIANCE}, and no feature map.
     *
     * @param dimensionality number of weights
     */
    public BayesianLinearRegression(int dimensionality) {
        this(null, DEFAULT_OUTPUT_VARIANCE, new MultivariateGaussian(
            VectorFactory.getDefault().createVector(dimensionality),
            MatrixFactory.getDefault().createIdentity(dimensionality, dimensionality).scale(DEFAULT_WEIGHT_VARIANCE)));
    }

    /**
     * Creates a regression with explicit parameters.
     *
     * @param featureMap maps inputs to feature vectors
     * @param outputVariance output noise variance; must be > 0
     * @param weightPrior prior over the weights
     * @throws IllegalArgumentException if {@code outputVariance <= 0}
     */
    public BayesianLinearRegression(Evaluator<? super InputType, Vector> featureMap, double outputVariance, MultivariateGaussian weightPrior) {
        super(featureMap);
        setOutputVariance(outputVariance);
        setWeightPrior(weightPrior);
    }

    /**
     * Clones this object. The weight prior is cloned when non-null; copying of the
     * remaining state is delegated to the superclass.
     *
     * @return clone of this regression
     */
    @Override
    @SuppressWarnings("unchecked")
    public BayesianLinearRegression<InputType> mo784clone() {
        final BayesianLinearRegression<InputType> clone = (BayesianLinearRegression<InputType>) super.mo784clone();
        clone.setWeightPrior((MultivariateGaussian) ObjectUtil.cloneSafe(getWeightPrior()));
        return clone;
    }

    /**
     * Computes the posterior over the weights given the observed data.
     *
     * @param data observed input-output pairs; each pair's weight (via
     *     {@link DatasetUtil#getWeight}) scales its contribution
     * @return posterior over the weights
     */
    @Override
    public MultivariateGaussian.PDF learn(Collection<? extends InputOutputPair<? extends InputType, Double>> data) {
        final MultivariateGaussian prior = getWeightPrior();
        final Matrix priorPrecision = prior.getCovarianceInverse().mo784clone();

        final RingAccumulator<Matrix> precisionSum = new RingAccumulator<Matrix>();
        precisionSum.accumulate(priorPrecision);
        final RingAccumulator<Vector> zSum = new RingAccumulator<Vector>();
        zSum.accumulate(priorPrecision.times(prior.getMean()));

        for (InputOutputPair<? extends InputType, Double> pair : data) {
            final Vector phi = this.featureMap.evaluate(pair.getInput());
            // Scaled copy of phi; phi itself stays unscaled for the outer product.
            final Vector scaledPhi = phi.mo784clone();
            final double beta = DatasetUtil.getWeight(pair) / this.outputVariance;
            if (beta != 1.0d) {
                scaledPhi.scaleEquals(beta);
            }
            precisionSum.accumulate(phi.outerProduct(scaledPhi));
            final double y = pair.getOutput().doubleValue();
            if (y != 1.0d) {
                scaledPhi.scaleEquals(y);
            }
            zSum.accumulate(scaledPhi);
        }

        final Matrix covariance = precisionSum.getSum().inverse();
        return new MultivariateGaussian.PDF(covariance.times(zSum.getSum()), covariance);
    }

    /**
     * Creates the output distribution for an input given fixed weights.
     *
     * @param input input value
     * @param weights weight vector
     * @return Gaussian with mean {@code phi(input) . weights} and variance {@link #getOutputVariance()}
     */
    @Override
    public UnivariateGaussian createConditionalDistribution(InputType input, Vector weights) {
        return new UnivariateGaussian(this.featureMap.evaluate(input).dotProduct(weights), getOutputVariance());
    }

    /**
     * Gets the prior over the weights.
     *
     * @return weight prior
     */
    public MultivariateGaussian getWeightPrior() {
        return this.weightPrior;
    }

    /**
     * Sets the prior over the weights.
     *
     * @param weightPrior weight prior
     */
    public void setWeightPrior(MultivariateGaussian weightPrior) {
        this.weightPrior = weightPrior;
    }

    /**
     * Gets the output noise variance.
     *
     * @return output variance (> 0)
     */
    public double getOutputVariance() {
        return this.outputVariance;
    }

    /**
     * Sets the output noise variance.
     *
     * @param outputVariance output variance
     * @throws IllegalArgumentException if {@code outputVariance <= 0}
     */
    public void setOutputVariance(double outputVariance) {
        if (outputVariance <= 0.0d) {
            throw new IllegalArgumentException("outputVariance must be > 0.0");
        }
        this.outputVariance = outputVariance;
    }

    /**
     * Creates the predictive distribution for the given weight posterior.
     *
     * @param posterior posterior over the weights
     * @return predictive distribution evaluator
     */
    @Override
    public BayesianLinearRegression<InputType>.PredictiveDistribution createPredictiveDistribution(MultivariateGaussian posterior) {
        return new PredictiveDistribution(posterior);
    }
}
