package org.openimaj.math.statistics.distribution;

import Jama.Matrix;
import cern.jet.random.Normal;
import cern.jet.random.engine.MersenneTwister;
import java.util.Arrays;
import java.util.Random;
import org.openimaj.math.matrix.MatrixUtils;

/* loaded from: input_file:org/openimaj/math/statistics/distribution/DiagonalMultivariateGaussian.class */
/**
 * A multivariate Gaussian distribution whose covariance matrix is diagonal,
 * i.e. whose dimensions are mutually independent. The covariance is stored
 * compactly as the vector of per-dimension variances.
 */
public class DiagonalMultivariateGaussian extends AbstractMultivariateGaussian {
    /** Natural log of 2*pi, used in the normalising constant of the density. */
    private static final double LOG_TWO_PI = Math.log(2.0d * Math.PI);

    /**
     * The diagonal of the covariance matrix (one variance per dimension).
     * This field is public and mutable; every method reads it afresh on each call.
     */
    public double[] variance;

    /**
     * Construct with the given mean and variances. Both arguments are stored by
     * reference, not copied.
     *
     * @param mean the mean; only row 0 is read, so a 1xN row vector is expected
     * @param variance the N per-dimension variances
     */
    public DiagonalMultivariateGaussian(Matrix mean, double[] variance) {
        this.mean = mean;
        this.variance = variance;
    }

    /**
     * Construct a distribution of the given dimensionality with a zero (1xN)
     * mean and unit variance in every dimension.
     *
     * @param ndims the number of dimensions
     */
    public DiagonalMultivariateGaussian(int ndims) {
        this.mean = new Matrix(1, ndims);
        this.variance = new double[ndims];
        Arrays.fill(this.variance, 1.0d);
    }

    /**
     * @return a new dense matrix with {@link #variance} on its diagonal
     */
    @Override // org.openimaj.math.statistics.distribution.MultivariateGaussian
    public Matrix getCovariance() {
        return MatrixUtils.diag(this.variance);
    }

    /**
     * Get a single element of the covariance matrix without materialising it.
     *
     * @param row the row index
     * @param column the column index
     * @return the variance of dimension {@code row} when {@code row == column},
     *         otherwise 0
     * @throws IndexOutOfBoundsException if either index is outside
     *         {@code [0, variance.length)}
     */
    @Override // org.openimaj.math.statistics.distribution.MultivariateGaussian
    public double getCovariance(int row, int column) {
        final int n = this.variance.length;
        if (row < 0 || row >= n || column < 0 || column >= n) {
            throw new IndexOutOfBoundsException(
                    "(" + row + ", " + column + ") is outside a " + n + "x" + n + " covariance matrix");
        }
        return row == column ? this.variance[row] : 0.0d;
    }

    /**
     * Evaluate the probability density at the given point.
     *
     * @param sample the point; must have at least {@code variance.length} elements
     * @return the density value
     */
    @Override // org.openimaj.math.statistics.distribution.AbstractMultivariateGaussian, org.openimaj.math.statistics.distribution.MultivariateDistribution
    public double estimateProbability(double[] sample) {
        // Evaluate in log space and exponentiate once at the end. Forming the
        // product of the variances and (2*pi)^N directly can overflow or
        // underflow in moderately high dimensions and produce NaN.
        return Math.exp(logNormaliser() - 0.5d * squaredMahalanobis(sample, this.mean.getArray()[0]));
    }

    /**
     * Evaluate the natural log of the probability density at the given point.
     *
     * @param sample the point; must have at least {@code variance.length} elements
     * @return the log-density value
     */
    @Override // org.openimaj.math.statistics.distribution.AbstractMultivariateGaussian, org.openimaj.math.statistics.distribution.MultivariateDistribution
    public double estimateLogProbability(double[] sample) {
        return logNormaliser() - 0.5d * squaredMahalanobis(sample, this.mean.getArray()[0]);
    }

    /**
     * Evaluate the natural log of the probability density at each of the given
     * points.
     *
     * @param samples the points, one per row; each must have at least
     *        {@code variance.length} elements
     * @return the log-density of each row, in the same order
     */
    @Override // org.openimaj.math.statistics.distribution.AbstractMultivariateGaussian, org.openimaj.math.statistics.distribution.MultivariateDistribution
    public double[] estimateLogProbability(double[][] samples) {
        // The normalising constant and the mean are shared by every row, so
        // compute them once.
        final double norm = logNormaliser();
        final double[] mu = this.mean.getArray()[0];

        final double[] result = new double[samples.length];
        for (int r = 0; r < samples.length; r++) {
            result[r] = norm - 0.5d * squaredMahalanobis(samples[r], mu);
        }
        return result;
    }

    /**
     * Draw samples from this distribution.
     *
     * @param count the number of samples to draw
     * @param random the source of randomness. Seeding it makes the output
     *        reproducible.
     * @return a {@code count x dims} array with one sample per row
     */
    @Override // org.openimaj.math.statistics.distribution.AbstractMultivariateGaussian, org.openimaj.math.statistics.distribution.MultivariateDistribution
    public double[][] sample(int count, Random random) {
        if (count == 0) {
            return new double[0][0];
        }

        // Seed the generator from the caller's RNG. An unseeded MersenneTwister
        // starts from Colt's constant default seed, so it would return the same
        // values on every call regardless of `random`.
        final Normal normal = new Normal(0.0d, 1.0d, new MersenneTwister(random.nextInt()));

        final int dims = this.mean.getColumnDimension();
        final double[] mu = this.mean.getArray()[0];
        final double[][] samples = new double[count][dims];
        for (int d = 0; d < dims; d++) {
            final double stddev = Math.sqrt(this.variance[d]);
            for (int s = 0; s < count; s++) {
                samples[s][d] = stddev * normal.nextDouble() + mu[d];
            }
        }
        return samples;
    }

    /**
     * Compute the log of the density's normalising constant,
     * {@code -0.5 * (N * log(2*pi) + sum(log(variance[i])))}. This is evaluated
     * additively so that it cannot overflow for large N.
     */
    private double logNormaliser() {
        double logDeterminant = 0.0d;
        for (final double v : this.variance) {
            logDeterminant += Math.log(v);
        }
        return -0.5d * (this.variance.length * LOG_TWO_PI + logDeterminant);
    }

    /**
     * Compute the squared Mahalanobis distance of {@code sample} from {@code mu}
     * under the diagonal covariance: {@code sum((x_i - mu_i)^2 / variance_i)}.
     */
    private double squaredMahalanobis(double[] sample, double[] mu) {
        double sum = 0.0d;
        for (int i = 0; i < this.variance.length; i++) {
            final double diff = sample[i] - mu[i];
            sum += (diff * diff) / this.variance[i];
        }
        return sum;
    }
}
