package org.openimaj.math.statistics.distribution;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import java.util.Arrays;
import java.util.Random;
import org.openimaj.math.matrix.MatrixUtils;
import org.openimaj.util.array.ArrayUtils;
import org.openimaj.util.pair.IndependentPair;

/**
 * A finite mixture of multivariate Gaussian distributions.
 * <p>
 * The density of a sample {@code x} is {@code sum_k weights[k] * N(x; mean_k, cov_k)}. The
 * weights are used exactly as supplied (they are not renormalised by this class); for a proper
 * distribution they should be non-negative and sum to one.
 */
public class MixtureOfGaussians extends AbstractMultivariateDistribution {
    /**
     * Value added to the diagonal of a covariance matrix that is not symmetric positive-definite
     * before its Cholesky decomposition is retried in
     * {@link #logProbability(double[][], MultivariateGaussian[])}.
     */
    public static final double MIN_COVAR_RECONDITION = 1.0E-7d;

    /** The mixture components. */
    public MultivariateGaussian[] gaussians;

    /** The mixing weights; {@code weights[k]} is the weight of {@code gaussians[k]}. */
    public double[] weights;

    /**
     * Construct a mixture from the given components and weights. The arrays are stored by
     * reference, not copied.
     *
     * @param gaussians the mixture components
     * @param weights the mixing weight of each component
     */
    public MixtureOfGaussians(MultivariateGaussian[] gaussians, double[] weights) {
        this.gaussians = gaussians;
        this.weights = weights;
    }

    /**
     * Draw a single sample from the mixture.
     *
     * @param random the source of randomness
     * @return the sample
     */
    @Override
    public double[] sample(Random random) {
        return sample(1, random)[0];
    }

    /**
     * Draw samples from the mixture. For each sample a component is first chosen with probability
     * given by its weight (by inverting the cumulative weight distribution); the sample is then
     * drawn from that component's Gaussian.
     *
     * @param count the number of samples to draw
     * @param random the source of randomness
     * @return {@code count} samples, one per row
     */
    @Override
    public double[][] sample(int count, Random random) {
        final double[] cumulativeWeights = ArrayUtils.cumulativeSum(this.weights);
        final int lastComponent = this.gaussians.length - 1;
        final double[][] samples = new double[count][this.gaussians[0].getMean().getColumnDimension()];

        // Choose a component for every requested sample.
        final int[] components = new int[count];
        for (int i = 0; i < count; i++) {
            int idx = Arrays.binarySearch(cumulativeWeights, random.nextDouble());
            if (idx < 0) {
                // No exact hit: binarySearch returned -(insertionPoint) - 1, and the insertion
                // point is the first component whose cumulative weight exceeds the draw.
                idx = -idx - 1;
            }
            if (idx > lastComponent) {
                // The weights' floating-point sum may fall slightly short of 1.
                idx = lastComponent;
            }
            components[i] = idx;
        }

        // Draw all samples belonging to each component in one batch.
        for (int k = 0; k <= lastComponent; k++) {
            final int[] members = ArrayUtils.search(components, k);
            if (members.length == 0) {
                continue;
            }
            final double[][] drawn = this.gaussians[k].sample(members.length, random);
            for (int j = 0; j < drawn.length; j++) {
                samples[members[j]] = drawn[j];
            }
        }
        return samples;
    }

    /**
     * Estimate the log-density of the mixture at a single point.
     *
     * @param sample the point
     * @return the log-density
     */
    @Override
    public double estimateLogProbability(double[] sample) {
        return estimateLogProbability(new double[][] { sample })[0];
    }

    /**
     * Estimate the log-density of the mixture at each of the given points.
     *
     * @param samples the points, one per row
     * @return the log-density of each point (empty if {@code samples} is empty)
     * @throws IllegalArgumentException if the data dimensionality does not match the model
     */
    public double[] estimateLogProbability(double[][] samples) {
        if (samples.length == 0) {
            return new double[0];
        }
        checkDimensions(samples);

        final double[][] weightedLogProb = computeWeightedLogProb(samples);
        final double[] logProb = new double[samples.length];
        for (int i = 0; i < samples.length; i++) {
            logProb[i] = logSumExp(weightedLogProb[i]);
        }
        return logProb;
    }

    /**
     * Compute {@code log(weights[k]) + log N(x_i; mean_k, cov_k)} for every sample {@code i} and
     * component {@code k}.
     *
     * @param samples the points, one per row
     * @return a {@code samples.length x gaussians.length} array of weighted log-densities
     */
    protected double[][] computeWeightedLogProb(double[][] samples) {
        final double[][] logProb = logProbability(samples, this.gaussians);
        final int nComponents = logProb[0].length;

        final double[] logWeights = new double[nComponents];
        for (int k = 0; k < nComponents; k++) {
            logWeights[k] = Math.log(this.weights[k]);
        }
        for (final double[] row : logProb) {
            for (int k = 0; k < nComponents; k++) {
                row[k] += logWeights[k];
            }
        }
        return logProb;
    }

    /**
     * Compute the log-density of every sample under every Gaussian. If a covariance matrix is not
     * symmetric positive-definite, {@link #MIN_COVAR_RECONDITION} is added to its diagonal before
     * the Cholesky decomposition is retried.
     *
     * @param samples the points, one per row
     * @param gaussians the Gaussians to evaluate
     * @return a {@code samples.length x gaussians.length} array of log-densities
     */
    public static double[][] logProbability(double[][] samples, MultivariateGaussian[] gaussians) {
        final int nDims = samples[0].length;
        final int nComponents = gaussians.length;
        final int nSamples = samples.length;
        final Matrix data = new Matrix(samples);
        final double[][] logProb = new double[nSamples][nComponents];
        final double logNormConst = nDims * Math.log(2 * Math.PI);

        for (int k = 0; k < nComponents; k++) {
            final Matrix mean = gaussians[k].getMean();
            final Matrix covariance = gaussians[k].getCovariance();

            CholeskyDecomposition chol = covariance.chol();
            if (!chol.isSPD()) {
                final Matrix jitter = Matrix.identity(nDims, nDims).timesEquals(MIN_COVAR_RECONDITION);
                chol = covariance.plus(jitter).chol();
            }
            final Matrix lower = chol.getL();

            // log|cov| = 2 * sum(log(diag(L))) for cov = L * L^T.
            final double[][] l = lower.getArray();
            double logDet = 0.0d;
            for (int d = 0; d < nDims; d++) {
                logDet += Math.log(l[d][d]);
            }
            logDet *= 2.0d;

            // Row i of z is L^-1 (x_i - mean); its squared norm is the Mahalanobis distance.
            final double[][] z = lower
                    .solve(MatrixUtils.minusRow(data, mean.getArray()[0]).transpose())
                    .transpose()
                    .getArray();
            for (int i = 0; i < nSamples; i++) {
                double mahalanobis = 0.0d;
                for (int d = 0; d < nDims; d++) {
                    mahalanobis += z[i][d] * z[i][d];
                }
                logProb[i][k] = -0.5d * (mahalanobis + logDet + logNormConst);
            }
        }
        return logProb;
    }

    /**
     * Predict the posterior over mixture components for a single point.
     * <p>
     * NOTE: despite the name, the returned values are posterior probabilities (responsibilities),
     * not their logarithms; see {@link #predictLogPosterior(double[][])}.
     *
     * @param sample the point
     * @return the posterior of each component
     */
    public double[] predictLogPosterior(double[] sample) {
        return predictLogPosterior(new double[][] { sample })[0];
    }

    /**
     * Predict the posterior over mixture components for each point.
     * <p>
     * NOTE: despite the name, this returns the responsibilities computed by
     * {@link #scoreSamples(double[][])}, i.e. posterior probabilities rather than
     * log-probabilities. The behaviour is kept because callers may depend on it.
     *
     * @param samples the points, one per row
     * @return a {@code samples.length x gaussians.length} array of posteriors
     */
    public double[][] predictLogPosterior(double[][] samples) {
        return scoreSamples(samples).secondObject();
    }

    /**
     * Compute, for each point, the mixture log-density and the posterior probability of each
     * component (the responsibilities).
     *
     * @param samples the points, one per row
     * @return a pair of (per-sample log-density, per-sample per-component responsibilities)
     * @throws IllegalArgumentException if the data dimensionality does not match the model
     */
    public IndependentPair<double[], double[][]> scoreSamples(double[][] samples) {
        if (samples.length == 0) {
            return IndependentPair.pair(new double[0], new double[0][0]);
        }
        checkDimensions(samples);

        final double[][] weightedLogProb = computeWeightedLogProb(samples);
        final int nComponents = this.gaussians.length;
        final double[] logProb = new double[samples.length];
        final double[][] responsibilities = new double[samples.length][nComponents];
        for (int i = 0; i < samples.length; i++) {
            logProb[i] = logSumExp(weightedLogProb[i]);
            for (int k = 0; k < nComponents; k++) {
                responsibilities[i][k] = Math.exp(weightedLogProb[i][k] - logProb[i]);
            }
        }
        return IndependentPair.pair(logProb, responsibilities);
    }

    /**
     * @return the mixture components (the internal array, not a copy)
     */
    public MultivariateGaussian[] getGaussians() {
        return this.gaussians;
    }

    /**
     * @return the mixing weights (the internal array, not a copy)
     */
    public double[] getWeights() {
        return this.weights;
    }

    /**
     * Estimate the density of the mixture at a single point.
     *
     * @param sample the point
     * @return the density, {@code exp(estimateLogProbability(sample))}
     */
    @Override
    public double estimateProbability(double[] sample) {
        return Math.exp(estimateLogProbability(sample));
    }

    /**
     * Predict the most probable mixture component for a point.
     *
     * @param sample the point
     * @return the index of the component with the largest posterior
     */
    public int predict(double[] sample) {
        return ArrayUtils.maxIndex(predictLogPosterior(sample));
    }

    /**
     * Check that the samples have the same dimensionality as the model.
     *
     * @throws IllegalArgumentException if the dimensionality differs
     */
    private void checkDimensions(double[][] samples) {
        if (samples[0].length != this.gaussians[0].getMean().getColumnDimension()) {
            throw new IllegalArgumentException("The number of dimensions of the given data is not compatible with the model");
        }
    }

    /**
     * Numerically stable {@code log(sum(exp(values)))}. Shifting by the maximum keeps the largest
     * term at {@code exp(0) = 1}, so the sum cannot underflow to zero when every value is very
     * negative.
     */
    private static double logSumExp(double[] values) {
        double max = Double.NEGATIVE_INFINITY;
        for (final double v : values) {
            if (v > max) {
                max = v;
            }
        }
        // Only shift by a finite maximum: if every value is -inf, or one is +inf, the plain sum
        // already gives the right answer and shifting would produce inf - inf = NaN.
        final double shift = Double.isInfinite(max) ? 0.0d : max;
        double sum = 0.0d;
        for (final double v : values) {
            sum += Math.exp(v - shift);
        }
        return shift + Math.log(sum);
    }
}
