package org.openimaj.demos;

import Jama.Matrix;
import com.jmatio.io.MatFileReader;
import com.jmatio.io.MatFileWriter;
import com.jmatio.types.MLArray;
import com.jmatio.types.MLDouble;
import com.jmatio.types.MLSingle;
import com.jmatio.types.MLStructure;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import org.openimaj.feature.FloatFV;
import org.openimaj.feature.local.list.MemoryLocalFeatureList;
import org.openimaj.image.feature.dense.gradient.dsift.FloatDSIFTKeypoint;
import org.openimaj.image.feature.local.aggregate.FisherVector;
import org.openimaj.math.matrix.algorithm.pca.PrincipalComponentAnalysis;
import org.openimaj.math.matrix.algorithm.pca.ThinSvdPrincipalComponentAnalysis;
import org.openimaj.math.statistics.distribution.DiagonalMultivariateGaussian;
import org.openimaj.math.statistics.distribution.MixtureOfGaussians;
import org.openimaj.math.statistics.distribution.MultivariateGaussian;
import org.openimaj.util.array.ArrayUtils;

/* loaded from: input_file:org/openimaj/demos/FVFWCheckPCAGMM.class */
/**
 * Demo that loads a Gaussian mixture model and a PCA projection from MATLAB
 * {@code .mat} files, projects stored dense-SIFT keypoints through the PCA,
 * appends normalised keypoint coordinates to each descriptor, aggregates the
 * result with a {@link FisherVector} and writes the aggregated vector back out
 * as a {@code .mat} file next to the input.
 */
public class FVFWCheckPCAGMM {
    private static final String GMM_MATLAB_FILE = "/Users/ss/Experiments/FVFW/data/gmm_512.mat";
    private static final String PCA_MATLAB_FILE = "/Users/ss/Experiments/FVFW/data/PCA_64.mat";
    private static final String[] FACE_DSIFTS = {"/Users/ss/Experiments/FVFW/data/Aaron_Eckhart_0001-pdfsift.bin"};

    /*
     * Divisors used to normalise keypoint coordinates before they are appended
     * to the descriptor. NOTE(review): presumably the width and height (in
     * pixels) of the face images the keypoints were extracted from - confirm.
     */
    private static final float X_NORMALISER = 125.0f;
    private static final float Y_NORMALISER = 160.0f;

    /**
     * A PCA whose basis and mean are supplied directly rather than learned.
     * The decompiler reports that this class was originally package-private.
     */
    public static class LoadedPCA extends ThinSvdPrincipalComponentAnalysis {
        /**
         * @param matrix the basis matrix, stored as given (not copied)
         * @param dArr   the mean vector, stored as given (not copied)
         */
        public LoadedPCA(Matrix matrix, double[] dArr) {
            super(matrix.getRowDimension());
            this.basis = matrix;
            this.mean = dArr;
        }
    }

    /**
     * Runs the demo over every file listed in {@code FACE_DSIFTS}.
     *
     * @param strArr command-line arguments (unused)
     * @throws IOException if any of the input files cannot be read or the
     *                     output cannot be written
     */
    public static void main(String[] strArr) throws IOException {
        final MixtureOfGaussians gmm = loadMoG(new File(GMM_MATLAB_FILE));
        final PrincipalComponentAnalysis pca = loadPCA(new File(PCA_MATLAB_FILE));
        // The two boolean flags are passed straight to the FisherVector
        // constructor; see its documentation for their meaning.
        final FisherVector<float[]> fisherVector = new FisherVector<float[]>(gmm, true, true);

        for (final String faceFile : FACE_DSIFTS) {
            final MemoryLocalFeatureList<FloatDSIFTKeypoint> keypoints =
                    MemoryLocalFeatureList.read(new File(faceFile), FloatDSIFTKeypoint.class);
            projectPCA(keypoints, pca);

            final FloatFV aggregated = fisherVector.aggregate(keypoints);
            System.out.println(String.format("%s: %s", faceFile, aggregated));
            System.out.println("Writing...");

            // Nothing else is invoked on the writer: constructing it with a
            // target file and data is relied on to perform the write.
            final File output = new File(faceFile + ".fisher.mat");
            new MatFileWriter(output, Arrays.asList(toMLArray(aggregated)));
        }
    }

    /**
     * Copies a feature vector into an N x 1 MATLAB double array named
     * {@code "fisherface"}.
     *
     * @param floatFV the feature vector to convert
     * @return the MATLAB array holding the same values
     */
    private static MLArray toMLArray(FloatFV floatFV) {
        final float[] values = floatFV.values;
        final MLDouble column = new MLDouble("fisherface", new int[]{values.length, 1});
        for (int row = 0; row < values.length; row++) {
            column.set(Double.valueOf(values[row]), row, 0);
        }
        return column;
    }

    /**
     * Replaces each keypoint's descriptor, in place, with its PCA projection
     * followed by two extra elements: the keypoint's x and y coordinates,
     * each divided by a fixed normaliser and shifted by -0.5. The list's
     * cached vector length is reset afterwards because the descriptor length
     * has changed.
     *
     * @param memoryLocalFeatureList     the keypoints to modify
     * @param principalComponentAnalysis the projection to apply
     */
    private static void projectPCA(MemoryLocalFeatureList<FloatDSIFTKeypoint> memoryLocalFeatureList, PrincipalComponentAnalysis principalComponentAnalysis) {
        for (final FloatDSIFTKeypoint keypoint : memoryLocalFeatureList) {
            final float[] projected = ArrayUtils.convertToFloat(
                    principalComponentAnalysis.project(ArrayUtils.convertToDouble(keypoint.descriptor)));
            final int projectedLength = projected.length;

            // Grow by two slots to hold the normalised spatial coordinates.
            final float[] augmented = Arrays.copyOf(projected, projectedLength + 2);
            augmented[projectedLength] = (keypoint.x / X_NORMALISER) - 0.5f;
            augmented[projectedLength + 1] = (keypoint.y / Y_NORMALISER) - 0.5f;
            keypoint.descriptor = augmented;
        }
        memoryLocalFeatureList.resetVecLength();
    }

    /**
     * Loads a PCA projection from the single-precision variable {@code "proj"}
     * of a MATLAB file. The stored M x N matrix is transposed to form the
     * basis, and the mean is an all-zero vector of length N.
     *
     * @param file the MATLAB file to read
     * @return a PCA using the loaded basis and a zero mean
     * @throws IOException if the file cannot be read, or {@code "proj"} is
     *                     missing or not a single-precision array
     */
    public static PrincipalComponentAnalysis loadPCA(File file) throws IOException {
        final MLSingle proj = requireSingle(new MatFileReader(file).getContent().get("proj"), "proj", file);
        final int rows = proj.getM();
        final int cols = proj.getN();

        // Fill the basis directly in transposed (N x M) layout.
        final Matrix basis = new Matrix(cols, rows);
        for (int col = 0; col < cols; col++) {
            for (int row = 0; row < rows; row++) {
                basis.set(col, row, proj.get(row, col));
            }
        }

        // Java zero-initialises arrays, which gives the zero mean.
        final double[] mean = new double[cols];
        return new LoadedPCA(basis, mean);
    }

    /**
     * Loads a mixture of diagonal-covariance Gaussians from the
     * {@code "codebook"} structure of a MATLAB file. The structure's
     * {@code "mean"} and {@code "variance"} fields are read with one column
     * per Gaussian and one row per dimension; {@code "coef"} supplies one
     * weight per Gaussian.
     *
     * @param file the MATLAB file to read
     * @return the loaded mixture model
     * @throws IOException if the file cannot be read, or the structure or one
     *                     of its fields is missing or of an unexpected type
     */
    public static MixtureOfGaussians loadMoG(File file) throws IOException {
        final MLArray codebookArray = new MatFileReader(file).getContent().get("codebook");
        if (!(codebookArray instanceof MLStructure)) {
            throw new IOException("Expected structure 'codebook' in " + file + " but found " + describe(codebookArray));
        }
        final MLStructure codebook = (MLStructure) codebookArray;

        final MLSingle means = requireSingle(codebook.getField("mean"), "codebook.mean", file);
        final MLSingle variances = requireSingle(codebook.getField("variance"), "codebook.variance", file);
        final MLSingle coefficients = requireSingle(codebook.getField("coef"), "codebook.coef", file);

        final int numGaussians = means.getN();
        final int numDims = means.getM();

        final MultivariateGaussian[] gaussians = new MultivariateGaussian[numGaussians];
        final double[] weights = new double[numGaussians];
        for (int g = 0; g < numGaussians; g++) {
            weights[g] = coefficients.get(g, 0);

            final DiagonalMultivariateGaussian gaussian = new DiagonalMultivariateGaussian(numDims);
            for (int d = 0; d < numDims; d++) {
                gaussian.mean.set(0, d, means.get(d, g));
                gaussian.variance[d] = variances.get(d, g);
            }
            gaussians[g] = gaussian;
        }
        return new MixtureOfGaussians(gaussians, weights);
    }

    /** Returns {@code array} as an {@link MLSingle}, or fails with a descriptive {@link IOException}. */
    private static MLSingle requireSingle(MLArray array, String name, File file) throws IOException {
        if (!(array instanceof MLSingle)) {
            throw new IOException("Expected single-precision array '" + name + "' in " + file
                    + " but found " + describe(array));
        }
        return (MLSingle) array;
    }

    /** Describes an array for error messages: its simple class name, or {@code "nothing"} when absent. */
    private static String describe(MLArray array) {
        return array == null ? "nothing" : array.getClass().getSimpleName();
    }
}
