package org.openimaj.math.matrix.algorithm;

import Jama.Matrix;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.openimaj.citation.annotation.Reference;
import org.openimaj.citation.annotation.ReferenceType;
import org.openimaj.math.matrix.GeneralisedEigenvalueProblem;
import org.openimaj.math.matrix.MatrixUtils;
import org.openimaj.util.array.ArrayUtils;
import org.openimaj.util.pair.IndependentPair;

/**
 * Fisher's Linear Discriminant Analysis (LDA).
 * <p>
 * Given training vectors grouped by class, this learns a basis of discriminant directions by
 * solving the symmetric generalised eigenvalue problem {@code Sb w = lambda Sw w}, where
 * {@code Sw} is the within-class scatter matrix and {@code Sb} is the between-class scatter
 * matrix {@code sum_i n_i (mu_i - mu)(mu_i - mu)^T}. Because the class-mean deviations are
 * weighted by class size and the overall mean is the instance-weighted mean, {@code Sb} has
 * rank at most {@code c - 1} for {@code c} classes, so at most {@code c - 1} directions are
 * learned.
 * <p>
 * The {@code learnBasis*} methods overwrite the learned state of this instance; no
 * synchronisation is performed.
 */
@Reference(type = ReferenceType.Article, author = {"Fisher, Ronald A."}, title = "{The use of multiple measurements in taxonomic problems}", year = "1936", journal = "Annals Eugen.", pages = {"179", "", "188"}, volume = "7", customData = {"citeulike-article-id", "764226", "keywords", "classification", "posted-at", "2006-09-18 14:06:16", "priority", "2"})
public class LinearDiscriminantAnalysis {
    /** Number of discriminant directions to learn; clamped to {@code c - 1} when learning. */
    protected int numComponents;

    /** Learned basis: a (dimensions x components) matrix with one direction per column. */
    protected Matrix eigenvectors;

    /** Generalised eigenvalues returned alongside {@link #eigenvectors} by the eigen-solver. */
    protected double[] eigenvalues;

    /** Instance-weighted mean of all training vectors; subtracted before projection. */
    protected double[] mean;

    /** Per-class and overall statistics of a training set, produced by computeMeans. */
    private static final class MeanData {
        /** Mean of all instances (each instance weighted equally, not each class). */
        double[] overallMean;
        /** classMeans[i] is the mean vector of class i. */
        double[][] classMeans;
        /** classSizes[i] is the number of instances in class i. */
        int[] classSizes;
        /** Total number of instances across all classes. */
        int numInstances;
    }

    /**
     * Construct an LDA that will learn the given number of components.
     *
     * @param numComponents the number of discriminant directions to learn. Values that are
     *            non-positive, or not smaller than the number of classes, are replaced by
     *            {@code numClasses - 1} when the basis is learned.
     */
    public LinearDiscriminantAnalysis(int numComponents) {
        this.numComponents = numComponents;
    }

    /**
     * Validate that there is at least one class, that every class is non-empty, and that all
     * vectors share the dimensionality of the first vector.
     *
     * @param data one array of row vectors per class
     * @throws IllegalArgumentException if any of the above conditions does not hold
     */
    private static void checkData(List<double[][]> data) {
        if (data.isEmpty()) {
            throw new IllegalArgumentException("At least one class of data is required");
        }

        for (int i = 0; i < data.size(); i++) {
            final double[][] classData = data.get(i);
            if (classData == null || classData.length == 0) {
                throw new IllegalArgumentException("Class " + i + " contains no instances");
            }
        }

        final int dims = data.get(0)[0].length;
        for (int i = 0; i < data.size(); i++) {
            for (final double[] row : data.get(i)) {
                if (row.length != dims) {
                    throw new IllegalArgumentException("Class " + i + " contains a vector of length "
                            + row.length + "; expected " + dims);
                }
            }
        }
    }

    /**
     * Compute the per-class means, class sizes and overall mean of already-validated data.
     *
     * @param data one array of row vectors per class (validated by checkData)
     * @return the computed statistics
     */
    private MeanData computeMeans(List<double[][]> data) {
        final int numClasses = data.size();
        final int dims = data.get(0)[0].length;

        final MeanData meanData = new MeanData();
        meanData.overallMean = new double[dims];
        meanData.classMeans = new double[numClasses][];
        meanData.classSizes = new int[numClasses];
        meanData.numInstances = 0;

        for (int i = 0; i < numClasses; i++) {
            final double[][] classData = data.get(i);
            final int n = classData.length;
            final double[] classMean = computeSum(classData, dims);

            for (int j = 0; j < dims; j++) {
                // accumulate the raw sum into the overall total before turning it into a mean
                meanData.overallMean[j] += classMean[j];
                classMean[j] /= n;
            }

            meanData.classMeans[i] = classMean;
            meanData.classSizes[i] = n;
            meanData.numInstances += n;
        }

        for (int j = 0; j < dims; j++) {
            meanData.overallMean[j] /= meanData.numInstances;
        }

        return meanData;
    }

    /**
     * Element-wise sum of a set of row vectors.
     *
     * @param rows the vectors to sum; each must have length {@code dims}
     * @param dims the vector length
     * @return a new array holding the element-wise sum
     */
    private static double[] computeSum(double[][] rows, int dims) {
        final double[] sum = new double[dims];
        for (final double[] row : rows) {
            for (int j = 0; j < dims; j++) {
                sum[j] += row[j];
            }
        }
        return sum;
    }

    /**
     * Learn the basis from a list of (class label, vector) pairs.
     *
     * @param data the labelled vectors; pairs with equal labels form one class
     * @throws IllegalArgumentException if the data is empty or the vectors differ in length
     */
    public void learnBasisIP(List<? extends IndependentPair<?, double[]>> data) {
        final Map<Object, List<double[]>> grouped = new HashMap<Object, List<double[]>>();

        for (final IndependentPair<?, double[]> pair : data) {
            final Object label = pair.firstObject();
            List<double[]> vectors = grouped.get(label);
            if (vectors == null) {
                vectors = new ArrayList<double[]>();
                grouped.put(label, vectors);
            }
            vectors.add(pair.secondObject());
        }

        learnBasisML(grouped);
    }

    /**
     * Learn the basis from a map of class label to that class's vectors.
     *
     * @param data map from class label to the list of vectors in that class
     * @throws IllegalArgumentException if the data is empty, a class is empty, or the vectors
     *             differ in length
     */
    public void learnBasisML(Map<?, List<double[]>> data) {
        final List<double[][]> classes = new ArrayList<double[][]>(data.size());
        for (final List<double[]> vectors : data.values()) {
            classes.add(vectors.toArray(new double[vectors.size()][]));
        }
        learnBasis(classes);
    }

    /**
     * Learn the basis from a list of classes, each given as a list of vectors.
     *
     * @param data one list of vectors per class
     * @throws IllegalArgumentException if the data is empty, a class is empty, or the vectors
     *             differ in length
     */
    public void learnBasisLL(List<List<double[]>> data) {
        final List<double[][]> classes = new ArrayList<double[][]>(data.size());
        for (final List<double[]> vectors : data) {
            classes.add(vectors.toArray(new double[vectors.size()][]));
        }
        learnBasis(classes);
    }

    /**
     * Learn the basis from a map of class label to that class's vectors (one row per vector).
     *
     * @param data map from class label to a (instances x dimensions) array
     * @throws IllegalArgumentException if the data is empty, a class is empty, or the vectors
     *             differ in length
     */
    public void learnBasis(Map<?, double[][]> data) {
        // Delegate to the List overload (the previous code recursed into itself forever).
        learnBasis(new ArrayList<double[][]>(data.values()));
    }

    /**
     * Learn the LDA basis from data grouped by class.
     *
     * @param data one (instances x dimensions) array per class; all vectors must share the
     *            same dimensionality
     * @throws IllegalArgumentException if {@code data} is empty, a class has no instances, or
     *             the vectors do not all have the same length
     */
    public void learnBasis(List<double[][]> data) {
        checkData(data);
        final int numClasses = data.size();

        // Sb has rank <= c - 1, so more than c - 1 directions would carry no discriminant info.
        if (numComponents <= 0 || numComponents >= numClasses) {
            numComponents = numClasses - 1;
        }

        final MeanData meanData = computeMeans(data);
        mean = meanData.overallMean;
        final int dims = mean.length;

        final Matrix withinScatter = new Matrix(dims, dims);
        final Matrix betweenScatter = new Matrix(dims, dims);

        for (int i = 0; i < numClasses; i++) {
            final double[] classMean = meanData.classMeans[i];

            // Sw += sum over class i of (x - mu_i)(x - mu_i)^T
            final Matrix centred = MatrixUtils.minusRow(new Matrix(data.get(i)), classMean);
            MatrixUtils.plusEquals(withinScatter, centred.transpose().times(centred));

            // Sb += n_i (mu_i - mu)(mu_i - mu)^T, weighted by this class's own size
            final double[] deviation = new double[dims];
            for (int j = 0; j < dims; j++) {
                deviation[j] = classMean[j] - mean[j];
            }
            final Matrix deviationRow = new Matrix(new double[][] { deviation });
            MatrixUtils.plusEquals(betweenScatter,
                    MatrixUtils.times(deviationRow.transpose().times(deviationRow), meanData.classSizes[i]));
        }

        final IndependentPair<Matrix, double[]> solution = GeneralisedEigenvalueProblem
                .symmetricGeneralisedEigenvectorsSorted(betweenScatter, withinScatter, numComponents);
        eigenvectors = solution.firstObject();
        eigenvalues = solution.secondObject();
    }

    /**
     * @return the learned basis (the internal matrix, not a copy), or null before learning
     */
    public Matrix getBasis() {
        return eigenvectors;
    }

    /**
     * Get a single learned basis vector.
     *
     * @param index the column index of the basis vector
     * @return a new array containing column {@code index} of the basis
     */
    public double[] getBasisVector(int index) {
        final double[][] basis = eigenvectors.getArray();
        final double[] vector = new double[eigenvectors.getRowDimension()];
        for (int r = 0; r < vector.length; r++) {
            vector[r] = basis[r][index];
        }
        return vector;
    }

    /**
     * @return the learned basis (the internal matrix, not a copy), or null before learning
     */
    public Matrix getEigenVectors() {
        return eigenvectors;
    }

    /**
     * @return the eigenvalues (the internal array, not a copy), or null before learning
     */
    public double[] getEigenValues() {
        return eigenvalues;
    }

    /**
     * @param index the eigenvalue index
     * @return the eigenvalue at {@code index}
     */
    public double getEigenValue(int index) {
        return eigenvalues[index];
    }

    /**
     * @return the mean of the training data (the internal array, not a copy), or null before
     *         learning
     */
    public double[] getMean() {
        return mean;
    }

    /**
     * Generate a vector in the original space from weights on the basis vectors.
     * <p>
     * Weights beyond the number of basis vectors are ignored; missing weights are treated as
     * zero.
     *
     * @param scalings the weight for each basis vector
     * @return {@code mean + basis * scalings}
     * @throws IllegalStateException if the basis has not been learned
     */
    public double[] generate(double[] scalings) {
        checkLearned();

        final int numBasis = eigenvectors.getColumnDimension();
        final int used = Math.min(numBasis, scalings.length);
        final Matrix weights = new Matrix(numBasis, 1);
        for (int i = 0; i < used; i++) {
            weights.set(i, 0, scalings[i]);
        }

        final double[] result = eigenvectors.times(weights).getColumnPackedCopy();
        for (int j = 0; j < result.length; j++) {
            result[j] += mean[j];
        }
        return result;
    }

    /**
     * Project each row of a matrix onto the learned basis. The input is not modified.
     *
     * @param data a matrix with one vector per row
     * @return a (rows x components) matrix of projections
     * @throws IllegalStateException if the basis has not been learned
     * @throws IllegalArgumentException if the column count differs from the training
     *             dimensionality
     */
    public Matrix project(Matrix data) {
        checkLearned();
        checkDimensions(data.getColumnDimension());

        final Matrix centred = data.copy();
        for (final double[] row : centred.getArray()) {
            for (int j = 0; j < mean.length; j++) {
                row[j] -= mean[j];
            }
        }
        return centred.times(eigenvectors);
    }

    /**
     * Project a single vector onto the learned basis.
     *
     * @param vector the vector to project
     * @return the projection, one value per basis vector
     * @throws IllegalStateException if the basis has not been learned
     * @throws IllegalArgumentException if the vector length differs from the training
     *             dimensionality
     */
    public double[] project(double[] vector) {
        checkLearned();
        checkDimensions(vector.length);

        final Matrix centred = new Matrix(1, vector.length);
        final double[] row = centred.getArray()[0];
        for (int j = 0; j < vector.length; j++) {
            row[j] = vector[j] - mean[j];
        }
        return centred.times(eigenvectors).getColumnPackedCopy();
    }

    /** Fail fast with a clear message if no basis has been learned yet. */
    private void checkLearned() {
        if (eigenvectors == null || mean == null) {
            throw new IllegalStateException(
                    "The basis has not been learned yet; call one of the learnBasis methods first");
        }
    }

    /** Fail fast if a vector length does not match the training dimensionality. */
    private void checkDimensions(int length) {
        if (length != mean.length) {
            throw new IllegalArgumentException(
                    "Expected vectors of length " + mean.length + " but got " + length);
        }
    }
}
