package org.openimaj.workinprogress;

import Jama.Matrix;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Random;
import org.openimaj.data.AbstractDataSource;
import org.openimaj.math.matrix.ThinSingularValueDecomposition;
import org.openimaj.workinprogress.optimisation.DifferentiableObjectiveFunction;
import org.openimaj.workinprogress.optimisation.EpochAnnealedLearningRate;
import org.openimaj.workinprogress.optimisation.SGD;
import org.openimaj.workinprogress.optimisation.params.KeyedParameters;

/**
 * Experimental estimate of a truncated singular value decomposition, learned by stochastic
 * gradient descent (SGD).
 * <p>
 * Components are learned one at a time. For component {@code k}, column {@code k} of the
 * unnormalised factors {@link #UprimeM} and {@link #VprimeM} is initialised with Gaussian noise
 * scaled by {@code 1/sqrt(rank)}. It is then trained on the squared error between every matrix
 * element and its prediction from components {@code 0..k}. Only column {@code k} is updated, so
 * earlier components stay fixed.
 * <p>
 * Once all components are trained, each column pair is normalised to produce {@code U}, a
 * diagonal {@code S} and {@code V}<sup>T</sup>. The components are kept in training order; they
 * are not re-sorted by magnitude.
 */
public class GD_SVD2 {
    private static final int maxEpochs = 300;
    private static final double initialLearningRate = 0.01d;
    // NOTE(review): unused. It may have been meant as the annealing argument of
    // EpochAnnealedLearningRate in the constructor (which currently receives maxEpochs) - confirm.
    private static final double annealingRate = 30.0d;

    /** Unnormalised left factor, {@code rows x rank}; trained in place by SGD. */
    Matrix UprimeM;
    /** Unnormalised right factor, {@code columns x rank}; trained in place by SGD. */
    Matrix VprimeM;
    /** Column-normalised left factor, {@code rows x rank}. */
    private Matrix UM;
    /** Normalised right factor stored transposed, {@code rank x columns}. */
    private Matrix VM;
    /** Diagonal matrix of estimated singular values, {@code rank x rank}. */
    private Matrix SM;

    /**
     * Squared-error objective over single matrix entries that optimises only component {@link #k}.
     * <p>
     * Each data item is a {@code {row, column, value}} triple. Row and column are stored as
     * doubles and truncated to {@code int} indices.
     */
    static class GD_SVD2_DOF implements DifferentiableObjectiveFunction<GD_SVD2, double[], KeyedParameters<String>> {
        /** Index of the component currently being trained; advanced by the {@link GD_SVD2} constructor. */
        public int k;

        GD_SVD2_DOF() {
        }

        /**
         * Returns {@code (value - prediction)^2} for one entry.
         * <p>
         * The prediction uses components {@code 0..k}.
         */
        @Override
        public double value(GD_SVD2 model, double[] entry) {
            // Use the target at full precision. Truncating it to int (as before) zeroed every
            // entry in (-1, 1) and disagreed with derivative().
            final double err = entry[2] - model.predict((int) entry[0], (int) entry[1], this.k);
            return err * err;
        }

        /**
         * Returns the update for one entry.
         * <p>
         * The update has two keys:
         * <ul>
         * <li>{@code "i<row>"} for the left factor, with value {@code err * V'[col][k]};</li>
         * <li>{@code "j<col>"} for the right factor, with value {@code err * U'[row][k]}.</li>
         * </ul>
         * Here {@code err = value - prediction}. This is -1/2 of the gradient of {@link #value}.
         * <p>
         * NOTE(review): {@link #updateModel} simply adds these values, so whether this yields
         * descent depends on how SGD scales/signs the result in between - confirm.
         */
        @Override
        public KeyedParameters<String> derivative(GD_SVD2 model, double[] entry) {
            final int row = (int) entry[0];
            final int col = (int) entry[1];
            final double err = entry[2] - model.predict(row, col, this.k);
            final double uRowK = model.UprimeM.getArray()[row][this.k];
            final double vColK = model.VprimeM.getArray()[col][this.k];

            final KeyedParameters<String> update = new KeyedParameters<>();
            update.set("i" + row, err * vColK);
            update.set("j" + col, err * uRowK);
            return update;
        }

        /**
         * Adds each value in {@code update} to column {@code k} of the unnormalised factors.
         * <p>
         * Keys starting with {@code 'i'} address rows of {@code UprimeM}; every other key
         * addresses a row of {@code VprimeM}. The remainder of the key is the row index.
         */
        @Override
        public void updateModel(GD_SVD2 model, KeyedParameters<String> update) {
            final double[][] u = model.UprimeM.getArray();
            final double[][] v = model.VprimeM.getArray();
            final Iterator<KeyedParameters.ObjectDoubleEntry<String>> it = update.iterator();
            while (it.hasNext()) {
                final KeyedParameters.ObjectDoubleEntry<String> e = it.next();
                final char factor = e.key.charAt(0);
                final int index = Integer.parseInt(e.key.substring(1));
                if (factor == 'i') {
                    u[index][this.k] += e.value;
                } else {
                    v[index][this.k] += e.value;
                }
            }
        }
    }

    /**
     * Predicts element ({@code row}, {@code col}) from components {@code 0..maxComponent}.
     * <p>
     * The prediction is the dot product of the corresponding rows of the unnormalised factors,
     * over components {@code 0..maxComponent} inclusive.
     */
    protected double predict(int row, int col, int maxComponent) {
        final double[] uRow = this.UprimeM.getArray()[row];
        final double[] vRow = this.VprimeM.getArray()[col];
        double sum = 0.0d;
        for (int c = 0; c <= maxComponent; c++) {
            sum += uRow[c] * vRow[c];
        }
        return sum;
    }

    /**
     * Learns a rank-{@code rank} decomposition of {@code matrix}.
     * <p>
     * Each component is trained with SGD configured for {@link #maxEpochs} epochs and batch
     * size 1. Initialisation uses a {@link Random} with fixed seed 0.
     *
     * @param matrix the matrix to decompose; its elements are only read
     * @param rank   number of components to learn
     */
    public GD_SVD2(Matrix matrix, int rank) {
        final int rows = matrix.getRowDimension();
        final int cols = matrix.getColumnDimension();
        final Random rng = new Random(0L);
        final double initScale = 1.0d / Math.sqrt(rank);

        this.UprimeM = new Matrix(rows, rank);
        this.VprimeM = new Matrix(cols, rank);
        final double[][] u = this.UprimeM.getArray();
        final double[][] v = this.VprimeM.getArray();

        // Keep a typed reference so the component index can be advanced without casts.
        final GD_SVD2_DOF fcn = new GD_SVD2_DOF();
        final SGD sgd = new SGD();
        sgd.fcn = fcn;
        sgd.batchSize = 1;
        sgd.maxEpochs = maxEpochs;
        sgd.learningRate = new EpochAnnealedLearningRate(initialLearningRate, maxEpochs);
        sgd.model = this;

        // The data source is stateless, so one instance serves every component.
        final AbstractDataSource<double[]> entries = entriesOf(matrix.getArray(), rows, cols);

        for (fcn.k = 0; fcn.k < rank; fcn.k++) {
            for (int r = 0; r < rows; r++) {
                u[r][fcn.k] = rng.nextGaussian() * initScale;
            }
            for (int c = 0; c < cols; c++) {
                v[c][fcn.k] = rng.nextGaussian() * initScale;
            }
            sgd.train(entries);
        }

        this.UM = new Matrix(rows, rank);
        this.SM = new Matrix(rank, rank);
        this.VM = new Matrix(rank, cols);
        final double[][] uOut = this.UM.getArray();
        final double[][] s = this.SM.getArray();
        final double[][] vtOut = this.VM.getArray();
        for (int c = 0; c < rank; c++) {
            final double uNorm = columnNorm(u, c);
            final double vNorm = columnNorm(v, c);
            // A zero-norm column would otherwise produce 0/0 = NaN; leave it as zeros instead.
            for (int r = 0; r < rows; r++) {
                uOut[r][c] = uNorm == 0.0d ? 0.0d : u[r][c] / uNorm;
            }
            for (int j = 0; j < cols; j++) {
                vtOut[c][j] = vNorm == 0.0d ? 0.0d : v[j][c] / vNorm;
            }
            s[c][c] = uNorm * vNorm;
        }
    }

    /** Returns the Euclidean norm of column {@code col} of {@code a}, summed in row order. */
    private static double columnNorm(double[][] a, int col) {
        double sumSq = 0.0d;
        for (final double[] row : a) {
            sumSq += row[col] * row[col];
        }
        return Math.sqrt(sumSq);
    }

    /**
     * Exposes every element of a {@code rows x cols} array as a {@code {row, column, value}} triple.
     * <p>
     * Elements are enumerated in row-major order. The dimensions are passed in explicitly so an
     * empty matrix has size 0 instead of failing on {@code data[0]}.
     */
    private static AbstractDataSource<double[]> entriesOf(final double[][] data, final int rows, final int cols) {
        return new AbstractDataSource<double[]>() {
            @Override
            public void getData(int startRow, int stopRow, double[][] out) {
                for (int idx = startRow, o = 0; idx < stopRow; idx++, o++) {
                    final int r = idx / cols;
                    final int c = idx % cols;
                    out[o][0] = r;
                    out[o][1] = c;
                    out[o][2] = data[r][c];
                }
            }

            @Override
            public double[] getData(int idx) {
                final int r = idx / cols;
                final int c = idx % cols;
                return new double[] { r, c, data[r][c] };
            }

            @Override
            public int numDimensions() {
                return 3;
            }

            @Override
            public int size() {
                return rows * cols;
            }

            @Override
            public double[][] createTemporaryArray(int size) {
                return new double[size][3];
            }
        };
    }

    /**
     * Demo entry point that compares the learned singular values with a direct decomposition.
     * <p>
     * It prints the learned {@code S} for a small 2x2 matrix, followed by the singular values
     * reported by {@link ThinSingularValueDecomposition}.
     */
    public static void main(String[] args) {
        final Matrix matrix = new Matrix(new double[][] { { 0.5d, 0.4d }, { 0.1d, 0.7d } });
        new GD_SVD2(matrix, 2).SM.print(5, 5);
        System.out.println(Arrays.toString(new ThinSingularValueDecomposition(matrix, 2).S));
    }
}
