package org.openimaj.ml.linear.projection;

import Jama.Matrix;
import org.openimaj.math.matrix.MatrixUtils;
import org.openimaj.math.matrix.algorithm.pca.ThinSvdPrincipalComponentAnalysis;

/* loaded from: input_file:org/openimaj/ml/linear/projection/LargeMarginDimensionalityReduction.class */
/**
 * Learns a linear projection {@link #W} and a bias/threshold {@link #b} from labelled pairs of
 * vectors using a large-margin (hinge) criterion on the projected squared distance
 * {@code dist = (x1 - x2)^T W^T W (x1 - x2)}.
 * <p>
 * The projection is initialised from a thin-SVD PCA of sample data
 * ({@link #initialise(double[][])}) and then refined one pair at a time with
 * {@link #step(double[], double[], boolean)}. With {@code y = +1} for "same" pairs and
 * {@code y = -1} for "different" pairs, a pair satisfies the margin when
 * {@code y * (b - dist) > 1}; only violating pairs trigger an update.
 * <p>
 * Instances are mutable and perform no synchronisation.
 */
public class LargeMarginDimensionalityReduction {
    /** Scale factor applied to the initial PCA-derived projection. */
    private static final double INITIAL_W_SCALE = 100.0d;
    /** Fixed initial value assigned to {@link #b} by {@link #initialise(double[][])}. */
    private static final double INITIAL_BIAS = 1000.0d;

    /** Number of output dimensions requested from the PCA initialisation. */
    public int ndims;
    /** Learning rate for updates of the projection {@link #W}. */
    public double wLearnRate = 0.01d;
    /** Learning rate for updates of the bias {@link #b}; 0 keeps the bias fixed. */
    public double bLearnRate = 0.0d;
    /**
     * The projection matrix; its column count equals the input dimensionality (see
     * {@link #project(double[])}). {@code null} until {@link #initialise(double[][])} is called.
     */
    public Matrix W;
    /** The bias/threshold the projected squared distance is compared against. */
    public double b;
    /** Cached {@code W^T W}; recomputed whenever {@link #W} changes. */
    public Matrix WtW;
    /** When {@code true}, {@link #step} prints diagnostic output to standard output. */
    public boolean verbose = false;

    /**
     * @param ndims number of dimensions to project to
     */
    public LargeMarginDimensionalityReduction(int ndims) {
        this.ndims = ndims;
    }

    /**
     * Initialises {@link #W} from a thin-SVD PCA of the given data and resets {@link #b}.
     * <p>
     * {@code W} is the PCA basis multiplied by the inverse of the diagonal eigenvalue matrix,
     * transposed and scaled by a constant factor. The bias is set to a fixed constant.
     *
     * @param data sample vectors used to learn the PCA basis (presumably one vector per row —
     *             confirm against {@code ThinSvdPrincipalComponentAnalysis.learnBasis})
     */
    public void initialise(double[][] data) {
        final ThinSvdPrincipalComponentAnalysis pca = new ThinSvdPrincipalComponentAnalysis(this.ndims);
        pca.learnBasis(data);

        final Matrix inverseEigenvalues = MatrixUtils.diag(pca.getEigenValues()).inverse();
        this.W = pca.getBasis().times(inverseEigenvalues).transpose().times(INITIAL_W_SCALE);
        this.WtW = this.W.transpose().times(this.W);

        // A fixed starting threshold. (An unused mean-pairwise-distance computation that used
        // to live here was discarded and has been removed.)
        this.b = INITIAL_BIAS;
    }

    /**
     * Performs one stochastic hinge-loss update for a labelled pair.
     *
     * @param first  first vector of the pair
     * @param second second vector of the pair; must have the same length as {@code first}
     * @param same   {@code true} if the pair should be close, {@code false} if it should be far
     * @return {@code true} if the pair violated the margin and the model was updated,
     *         {@code false} if the margin was already satisfied
     * @throws IllegalArgumentException if the vectors differ in length
     * @throws IllegalStateException    if {@link #initialise(double[][])} has not been called
     */
    public boolean step(double[] first, double[] second, boolean same) {
        ensureInitialised();
        final int label = same ? 1 : -1;
        final Matrix diff = difference(first, second);
        final double dist = squaredDistance(diff);

        if (label * (this.b - dist) > 1.0d) {
            return false;
        }
        if (this.verbose) {
            System.out.println(same + " " + dist);
        }

        // dLoss/dW is proportional to label * W * diff * diff^T. Computing (W * diff) * diff^T
        // avoids materialising the d x d outer product diff * diff^T.
        final Matrix wGradient = this.W.times(diff).times(diff.transpose());
        this.W.minusEquals(wGradient.times(this.wLearnRate * label));
        this.WtW = this.W.transpose().times(this.W);

        // dLoss/db = -label, so gradient descent *adds* label * rate: the threshold grows for
        // violating "same" pairs and shrinks for violating "different" pairs.
        this.b += label * this.bLearnRate;

        if (this.verbose) {
            System.out.println(MatrixUtils.toMatlabString(this.W));
            System.out.println(this.b);
        }
        return true;
    }

    /**
     * Projects a vector with the current {@link #W}.
     *
     * @param vector the input vector; its length must equal the column count of {@link #W}
     * @return the projected vector ({@code W * vector})
     * @throws IllegalStateException if {@link #initialise(double[][])} has not been called
     */
    public double[] project(double[] vector) {
        ensureInitialised();
        final Matrix column = new Matrix(new double[][] { vector }).transpose();
        return this.W.times(column).getColumnPackedCopy();
    }

    /** Returns {@code diff^T (W^T W) diff} for a column-vector difference. */
    private double squaredDistance(Matrix diff) {
        return diff.transpose().times(this.WtW).times(diff).get(0, 0);
    }

    /** Builds the column vector {@code first - second}, rejecting mismatched lengths. */
    private static Matrix difference(double[] first, double[] second) {
        if (first.length != second.length) {
            throw new IllegalArgumentException(
                    "vector lengths differ: " + first.length + " != " + second.length);
        }
        final Matrix diff = new Matrix(first.length, 1);
        for (int k = 0; k < first.length; k++) {
            diff.set(k, 0, first[k] - second[k]);
        }
        return diff;
    }

    /** Fails fast if the model has not yet been initialised. */
    private void ensureInitialised() {
        if (this.W == null || this.WtW == null) {
            throw new IllegalStateException("initialise(double[][]) must be called first");
        }
    }
}
