package org.openimaj.demos.sandbox.gmm;

import Jama.Matrix;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.openimaj.data.RandomData;
import org.openimaj.demos.sandbox.gmm.GaussianMixtureModelGenerator;
import org.openimaj.image.DisplayUtilities;
import org.openimaj.image.MBFImage;
import org.openimaj.image.colour.RGBColour;
import org.openimaj.math.geometry.point.Point2dImpl;
import org.openimaj.math.geometry.shape.Ellipse;
import org.openimaj.math.geometry.shape.EllipseUtilities;
import org.openimaj.math.statistics.distribution.MultivariateGaussian;

/* loaded from: input_file:org/openimaj/demos/sandbox/gmm/GaussianMixtureModelEM.class */
/**
 * Sandbox demo: fits a Gaussian mixture model to a set of points using the
 * expectation-maximisation (EM) algorithm.
 * <p>
 * Each call to {@link #step()} performs one E-step (computing, for every data
 * point, the posterior probability of each component) followed by one M-step
 * (re-estimating every component's mean, covariance and prior from those
 * posteriors).
 */
public class GaussianMixtureModelEM {
    /**
     * Degeneracy threshold. In the E-step, a point whose total weighted
     * likelihood falls below it is treated as an outlier. In the M-step, a
     * component whose summed posterior falls below it is not re-estimated.
     */
    private static final double SMALL = 1.0E-8d;

    /** Variance placed on the diagonal of every component's initial (isotropic) covariance. */
    private static final double INITIAL_VARIANCE = 800.0d;

    /**
     * Added to the diagonal of each re-estimated covariance so that a component
     * which collapses onto very few (or collinear) points keeps an invertible
     * covariance matrix.
     */
    private static final double COVARIANCE_RIDGE = 1.0E-6d;

    /** Number of mixture components. */
    private int nGaus;
    /** The data points, one row per point. */
    private double[][] data;
    /** {@code data.length x nGaus}; entry (i, k) is the posterior of component k for point i. */
    private Matrix gausPosterior;
    /** Mixing weight (prior) of each component. */
    private double[] gausPrior;
    /** The current component distributions. */
    private List<MultivariateGaussian> gaussians;
    /** Per-component sum, over all points, of the posteriors from the last E-step. */
    private double[] sumPosterior;

    /**
     * Creates an EM fitter over existing data.
     *
     * @param data       the data points, one row per point; every row is assumed
     *                   to have the same length as the first
     * @param nGaussians the number of mixture components to fit
     * @throws IllegalArgumentException if {@code data} is empty or
     *                                  {@code nGaussians < 1}
     */
    public GaussianMixtureModelEM(double[][] data, int nGaussians) {
        this.data = data;
        this.nGaus = nGaussians;
        init();
    }

    /**
     * Creates an EM fitter over points drawn from a generator.
     *
     * @param generator  source of the points; each generated point is copied
     * @param nPoints    how many points to draw
     * @param nGaussians the number of mixture components to fit
     * @throws IllegalArgumentException if {@code nPoints == 0} or
     *                                  {@code nGaussians < 1}
     */
    public GaussianMixtureModelEM(GaussianMixtureModelGenerator generator, int nPoints, int nGaussians) {
        this.data = new double[nPoints][generator.dimentions()];
        for (final double[] row : this.data) {
            System.arraycopy(generator.generate().point, 0, row, 0, row.length);
        }
        this.nGaus = nGaussians;
        init();
    }

    /**
     * Allocates the working state and initialises every component. Each
     * component gets an equal prior, a mean taken from a randomly chosen data
     * point, and an isotropic covariance of {@link #INITIAL_VARIANCE}.
     */
    private void init() {
        if (data.length == 0) {
            throw new IllegalArgumentException("at least one data point is required");
        }
        if (nGaus < 1) {
            throw new IllegalArgumentException("number of components must be positive: " + nGaus);
        }

        final int dims = data[0].length;
        this.gausPosterior = new Matrix(data.length, nGaus);
        this.gausPrior = new double[nGaus];
        this.gaussians = new ArrayList<MultivariateGaussian>(nGaus);
        this.sumPosterior = new double[nGaus];

        // NOTE(review): presumably returns nGaus indices in [0, data.length);
        // confirm whether they are guaranteed to be distinct.
        final int[] seedIndices = RandomData.getRandomIntArray(nGaus, 0, data.length);
        final Matrix initialCovar = Matrix.identity(dims, dims).times(INITIAL_VARIANCE);
        for (int k = 0; k < nGaus; k++) {
            gausPrior[k] = 1.0d / nGaus;
            final double[] seed = data[seedIndices[k]];
            final Matrix mean = new Matrix(new double[][] { Arrays.copyOf(seed, seed.length) });
            gaussians.add(new MultivariateGaussian(mean, initialCovar.copy()));
        }
    }

    /**
     * E-step: for every point, computes the posterior of each component
     * (prior x likelihood, normalised over components) into
     * {@link #gausPosterior}, and accumulates the per-component sums into
     * {@link #sumPosterior}.
     */
    private void e_step() {
        final float[] sample = new float[data[0].length];
        final double[] likelihood = new double[nGaus];
        this.sumPosterior = new double[nGaus];

        for (int i = 0; i < data.length; i++) {
            // convert once per point rather than once per component
            dataAsFloat(data[i], sample);

            double total = 0.0d;
            for (int k = 0; k < nGaus; k++) {
                likelihood[k] = gaussians.get(k).estimateProbability(sample);
                total += gausPrior[k] * likelihood[k];
            }
            // A point that no component explains would cause division by ~0.
            // Normalising by 1 instead leaves its posteriors near zero, so it
            // barely influences the M-step.
            if (total < SMALL) {
                total = 1.0d;
            }

            for (int k = 0; k < nGaus; k++) {
                final double posterior = (gausPrior[k] * likelihood[k]) / total;
                if (Double.isNaN(posterior)) {
                    System.err.println("NaN posterior for point " + i + ", component " + k);
                }
                gausPosterior.set(i, k, posterior);
                sumPosterior[k] += posterior;
            }
        }
    }

    /**
     * M-step: re-estimates every component's prior, mean and covariance as
     * posterior-weighted statistics of the data. A component whose summed
     * posterior is below {@link #SMALL} keeps its previous mean and covariance,
     * because dividing by that sum would produce NaN/Inf parameters.
     */
    private void m_step() {
        final int dims = data[0].length;
        final int nPoints = data.length;

        for (int k = 0; k < nGaus; k++) {
            final double weight = sumPosterior[k];
            gausPrior[k] = weight / nPoints;

            if (weight < SMALL) {
                continue;
            }

            // weighted mean
            final double[] mean = new double[dims];
            for (int i = 0; i < nPoints; i++) {
                final double p = gausPosterior.get(i, k);
                final double[] x = data[i];
                for (int d = 0; d < dims; d++) {
                    mean[d] += p * x[d];
                }
            }
            for (int d = 0; d < dims; d++) {
                mean[d] /= weight;
            }

            // weighted covariance about the new mean
            final Matrix covar = new Matrix(dims, dims);
            final double[][] c = covar.getArray();
            final double[] diff = new double[dims];
            for (int i = 0; i < nPoints; i++) {
                final double p = gausPosterior.get(i, k);
                for (int d = 0; d < dims; d++) {
                    diff[d] = data[i][d] - mean[d];
                }
                for (int r = 0; r < dims; r++) {
                    for (int col = 0; col < dims; col++) {
                        c[r][col] += p * diff[r] * diff[col];
                    }
                }
            }
            for (int r = 0; r < dims; r++) {
                for (int col = 0; col < dims; col++) {
                    c[r][col] /= weight;
                }
                // keep the covariance invertible if the component collapses
                c[r][r] += COVARIANCE_RIDGE;
            }

            gaussians.set(k, new MultivariateGaussian(new Matrix(new double[][] { mean }), covar));
        }
    }

    /** Performs one EM iteration: an E-step followed by an M-step. */
    public void step() {
        e_step();
        m_step();
    }

    /**
     * Copies the first {@code out.length} values of {@code in} into {@code out},
     * narrowing them to float.
     *
     * @param in  source values; must have at least {@code out.length} elements
     * @param out destination buffer, overwritten and returned
     * @return {@code out}
     */
    private float[] dataAsFloat(double[] in, float[] out) {
        for (int i = 0; i < out.length; i++) {
            out[i] = (float) in[i];
        }
        return out;
    }

    /**
     * Demo entry point. It samples points from a three-component mixture
     * described by three ellipses and shows the points coloured two ways: by
     * generating component, and by blended responsibilities. It then animates
     * 100 EM iterations, drawing each fitted component as an ellipse.
     *
     * @param args ignored
     * @throws InterruptedException if interrupted while pausing between frames
     */
    public static void main(String[] args) throws InterruptedException {
        final int nPoints = 1000;
        final int nIterations = 100;

        final Ellipse e1 = new Ellipse(200.0d, 200.0d, 40.0d, 20.0d, Math.PI / 3);
        final Ellipse e2 = new Ellipse(220.0d, 150.0d, 60.0d, 20.0d, -Math.PI / 3);
        final Ellipse e3 = new Ellipse(180.0d, 200.0d, 80.0d, 20.0d, -Math.PI / 3);
        final Float[][] colours = { RGBColour.RED, RGBColour.GREEN, RGBColour.BLUE };

        final MBFImage base = new MBFImage(400, 400, 3);
        base.drawShape(e1, RGBColour.RED);
        base.drawShape(e2, RGBColour.GREEN);
        base.drawShape(e3, RGBColour.BLUE);

        final GaussianMixtureModelGenerator2D generator = new GaussianMixtureModelGenerator2D(e1, e2, e3);
        final MBFImage byComponent = base.clone();
        final MBFImage byResponsibility = base.clone();
        final MBFImage emProgress = base.clone();

        final double[][] points = new double[nPoints][];
        for (int i = 0; i < nPoints; i++) {
            final GaussianMixtureModelGenerator.Generated sample = generator.generate();
            // Copy the point (as the generator-based constructor does) so that
            // rows never alias a buffer the generator might reuse.
            points[i] = sample.point.clone();

            final Point2dImpl pt = new Point2dImpl((float) sample.point[0], (float) sample.point[1]);
            byComponent.drawPoint(pt, colours[sample.distribution], 3);

            // blend the component colours, weighted by their responsibilities
            final float[] mix = new float[3];
            for (int k = 0; k < colours.length; k++) {
                for (int ch = 0; ch < mix.length; ch++) {
                    mix[ch] = (float) (mix[ch] + colours[k][ch] * sample.responsibilities[k]);
                }
            }
            double total = 0.0d;
            for (final float v : mix) {
                total += v;
            }
            final Float[] blended = new Float[mix.length];
            for (int ch = 0; ch < mix.length; ch++) {
                blended[ch] = total > 0 ? (float) (mix[ch] / total) : 0.0f;
            }
            byResponsibility.drawPoint(pt, blended, 3);
        }
        DisplayUtilities.display(byComponent);
        DisplayUtilities.display(byResponsibility);

        final GaussianMixtureModelEM em = new GaussianMixtureModelEM(points, colours.length);
        for (int iter = 0; iter < nIterations; iter++) {
            emProgress.fill(RGBColour.BLACK);
            for (int k = 0; k < colours.length; k++) {
                emProgress.drawShape(EllipseUtilities.ellipseFromGaussian(em.gaussians.get(k), 2.0f), colours[k]);
            }
            DisplayUtilities.displayName(emProgress, "EM Progress");
            Thread.sleep(500L);
            em.step();
        }
    }
}
