package org.openimaj.ml.linear.learner.matlib;

import ch.akuhn.matrix.Matrix;
import ch.akuhn.matrix.SparseMatrix;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.openimaj.io.ReadWriteableBinary;
import org.openimaj.math.matrix.DiagonalMatrix;
import org.openimaj.math.matrix.MatlibMatrixUtils;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.OnlineLearner;
import org.openimaj.ml.linear.learner.matlib.init.InitStrategy;
import org.openimaj.ml.linear.learner.matlib.init.SparseSingleValueInitStrat;
import org.openimaj.ml.linear.learner.matlib.loss.LossFunction;
import org.openimaj.ml.linear.learner.matlib.loss.MatLossFunction;
import org.openimaj.ml.linear.learner.matlib.regul.Regulariser;

/* loaded from: input_file:org/openimaj/ml/linear/learner/matlib/MatlibBilinearSparseOnlineLearner.class */
/**
 * Online learner for a sparse bilinear model built on ch.akuhn Matlib matrices.
 * <p>
 * The model consists of:
 * <ul>
 * <li>{@code w}, shaped (rows of X) x (columns of Y);</li>
 * <li>{@code u}, shaped (columns of X) x (columns of Y);</li>
 * <li>a square {@code bias}, shaped (columns of Y) x (columns of Y).</li>
 * </ul>
 * {@link #predict(Matrix)} computes {@code u^T X^T w} (plus {@code bias} when the BIAS
 * parameter is set) and wraps the result in a {@link DiagonalMatrix}.
 * <p>
 * Each call to {@link #process(Matrix, Matrix)} first dampens the current model. It then
 * alternates proximal-gradient updates of {@code w}, {@code u} and, optionally,
 * {@code bias}. It stops when the averaged relative change drops below BICONVEX_TOL,
 * or after BICONVEX_MAXITER iterations.
 */
public class MatlibBilinearSparseOnlineLearner implements OnlineLearner<Matrix, Matrix>, ReadWriteableBinary {
    static Logger logger = LogManager.getLogger(MatlibBilinearSparseOnlineLearner.class);

    /** Configuration; re-read into the cached fields below by {@link #reinitParams()}. */
    protected BilinearLearnerParameters params;
    /** Left model matrix, (rows of X) x (columns of Y); null until trained, set or loaded. */
    protected Matrix w;
    /** Right model matrix, (columns of X) x (columns of Y); null until trained, set or loaded. */
    protected Matrix u;
    /** Loss used for gradients; wrapped in a {@link MatLossFunction} if not already a matrix loss. */
    protected LossFunction loss;
    /** Regulariser whose proximal operator follows every gradient step. */
    protected Regulariser regul;
    protected Double lambda_w;
    protected Double lambda_u;
    protected Boolean biasMode;
    /** Square (columns of Y) bias matrix; only reported/used for prediction when biasMode is true. */
    protected Matrix bias;
    /** Diagonal matrix of 1.0 handed to the loss as its "X" during the bias update. */
    protected Matrix diagX;
    protected Double eta0_u;
    protected Double eta0_w;
    private Boolean forceSparcity;
    private Boolean zStandardise;
    /**
     * True until the first iteration of {@link #process}. That iteration projects X with
     * a {@link SparseSingleValueInitStrat}(1.0) matrix instead of the current {@code u}.
     */
    private boolean nodataseen;

    /** Creates a learner configured with default {@link BilinearLearnerParameters}. */
    public MatlibBilinearSparseOnlineLearner() {
        this(new BilinearLearnerParameters());
    }

    /**
     * Creates a learner using the given parameters.
     *
     * @param bilinearLearnerParameters configuration; kept by reference
     */
    public MatlibBilinearSparseOnlineLearner(BilinearLearnerParameters bilinearLearnerParameters) {
        this.params = bilinearLearnerParameters;
        reinitParams();
    }

    /**
     * Re-reads every cached setting from {@link #params} and marks the learner as
     * not having seen any data. The model matrices are left untouched.
     */
    public void reinitParams() {
        this.loss = (LossFunction) this.params.getTyped(BilinearLearnerParameters.LOSS);
        this.regul = (Regulariser) this.params.getTyped(BilinearLearnerParameters.REGUL);
        this.lambda_w = (Double) this.params.getTyped(BilinearLearnerParameters.LAMBDA_W);
        this.lambda_u = (Double) this.params.getTyped(BilinearLearnerParameters.LAMBDA_U);
        this.biasMode = (Boolean) this.params.getTyped(BilinearLearnerParameters.BIAS);
        this.eta0_u = (Double) this.params.getTyped(BilinearLearnerParameters.ETA0_U);
        this.eta0_w = (Double) this.params.getTyped(BilinearLearnerParameters.ETA0_W);
        this.forceSparcity = (Boolean) this.params.getTyped(BilinearLearnerParameters.FORCE_SPARCITY);
        this.zStandardise = (Boolean) this.params.getTyped(BilinearLearnerParameters.Z_STANDARDISE);
        if (!this.loss.isMatrixLoss()) {
            this.loss = new MatLossFunction(this.loss);
        }
        this.nodataseen = true;
    }

    /**
     * Allocates w, u, bias (and diagX in bias mode) using the configured init strategies.
     *
     * @param xRows rows of X (rows of w)
     * @param xCols columns of X (rows of u)
     * @param yCols columns of Y (columns of w and u; size of the square bias)
     */
    private void initParams(Matrix x, Matrix y, int xRows, int xCols, int yCols) {
        final InitStrategy wInit = getInitStrat(BilinearLearnerParameters.WINITSTRAT, x, y);
        final InitStrategy uInit = getInitStrat(BilinearLearnerParameters.UINITSTRAT, x, y);
        this.w = wInit.init(xRows, yCols);
        this.u = uInit.init(xCols, yCols);
        this.bias = SparseMatrix.sparse(yCols, yCols);
        if (this.biasMode.booleanValue()) {
            this.bias = getInitStrat(BilinearLearnerParameters.BIASINITSTRAT, x, y).init(yCols, yCols);
            this.diagX = new DiagonalMatrix(yCols, 1.0d);
        }
    }

    /** Looks up an {@link InitStrategy} parameter; the matrices are currently unused. */
    private InitStrategy getInitStrat(String str, Matrix matrix, Matrix matrix2) {
        return (InitStrategy) this.params.getTyped(str);
    }

    /**
     * Updates the model with one batch.
     * <p>
     * The current model is first scaled by {@code 1 - DAMPENING}. Each iteration then:
     * <ol>
     * <li>updates w against X projected through u;</li>
     * <li>updates u against X projected through the new w;</li>
     * <li>in bias mode, updates the bias.</li>
     * </ol>
     * Every iteration's result is committed. Iteration stops when the mean relative change
     * is below BICONVEX_TOL (always after one iteration if the tolerance is negative), or
     * after BICONVEX_MAXITER iterations.
     *
     * @param matrix X; its shape fixes the rows of w and u on first use
     * @param matrix2 Y; only row 0 is read (see {@link #expandY(Matrix)})
     */
    @Override // org.openimaj.ml.linear.learner.OnlineLearner
    public void process(Matrix matrix, Matrix matrix2) {
        final int xRows = matrix.rowCount();
        final int xCols = matrix.columnCount();
        final int yCols = matrix2.columnCount();
        if (this.w == null) {
            initParams(matrix, matrix2, xRows, xCols, yCols);
        }
        // diagX is only created by initParams; a model that came from m16clone(),
        // readBinary() or setW() skips that, so create it here when needed.
        if (this.biasMode.booleanValue() && this.diagX == null) {
            this.diagX = new DiagonalMatrix(yCols, 1.0d);
        }

        final double weighting = 1.0d - ((Double) this.params.getTyped(BilinearLearnerParameters.DAMPENING)).doubleValue();
        logger.debug("... dampening w, u and bias by: " + weighting);
        MatlibMatrixUtils.scaleInplace(this.w, weighting);
        MatlibMatrixUtils.scaleInplace(this.u, weighting);
        if (this.biasMode.booleanValue()) {
            MatlibMatrixUtils.scaleInplace(this.bias, weighting);
        }
        this.loss.setY(expandY(matrix2));

        final Double tolerance = (Double) this.params.getTyped(BilinearLearnerParameters.BICONVEX_TOL);
        final Integer maxIter = (Integer) this.params.getTyped(BilinearLearnerParameters.BICONVEX_MAXITER);

        double totalW;
        double totalU;
        double totalBias;
        int iter = 0;
        while (true) {
            if (this.biasMode.booleanValue()) {
                this.loss.setBias(this.bias);
            }
            iter++;
            final double etaU = etat(iter, this.eta0_u.doubleValue());
            final double etaW = etat(iter, this.eta0_w.doubleValue());
            final double lambdaU = lambdat(iter, this.lambda_u.doubleValue());
            final double lambdaW = lambdat(iter, this.lambda_w.doubleValue());

            // NOTE: Z_STANDARDISE is read into zStandardise but is not applied anywhere.
            final Matrix projectedByU;
            if (this.nodataseen) {
                this.nodataseen = false;
                final Matrix seedU = new SparseSingleValueInitStrat(1.0d).init(this.u.columnCount(), this.u.rowCount());
                projectedByU = MatlibMatrixUtils.dotProductTranspose(seedU, matrix);
            } else {
                projectedByU = MatlibMatrixUtils.dotProductTransposeTranspose(this.u, matrix);
            }
            this.loss.setX(projectedByU);
            final Matrix newW = updateW(this.w, etaW, lambdaW);
            // The u step is taken against X projected through the freshly updated w.
            this.loss.setX(MatlibMatrixUtils.transposeDotProduct(newW, matrix));
            final Matrix newU = updateU(this.u, etaU, lambdaU);

            final double changeW = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(newW, this.w));
            totalW = MatlibMatrixUtils.normF(this.w);
            final double changeU = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(newU, this.u));
            totalU = MatlibMatrixUtils.normF(this.u);
            // Both relative changes count towards convergence (a zero-norm model counts as 0).
            final double ratioU = totalU != 0.0d ? changeU / totalU : 0.0d;
            final double ratioW = totalW != 0.0d ? changeW / totalW : 0.0d;

            // Commit the step; without this the loop would re-solve from the same state.
            this.w = newW;
            this.u = newU;

            totalBias = 0.0d;
            final double ratio;
            if (this.biasMode.booleanValue()) {
                final Matrix prediction = MatlibMatrixUtils.dotProduct(MatlibMatrixUtils.dotProductTransposeTranspose(newU, matrix), newW);
                MatlibMatrixUtils.plusInplace(prediction, this.bias);
                this.loss.setBias(null);
                this.loss.setX(this.diagX);
                final Matrix newBias = updateBias(this.loss.gradient(prediction), biasEtat(iter));
                final double changeBias = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(newBias, this.bias));
                totalBias = MatlibMatrixUtils.normF(this.bias);
                final double ratioB = totalBias != 0.0d ? changeBias / totalBias : 0.0d;
                this.bias = newBias;
                ratio = (ratioU + ratioW + ratioB) / 3.0d;
            } else {
                ratio = (ratioU + ratioW) / 2.0d;
            }

            if (iter % 3 == 0) {
                logger.debug(String.format("Iter: %d. Last Ratio: %2.3f", Integer.valueOf(iter), Double.valueOf(ratio)));
                logger.debug("W row sparcity: " + MatlibMatrixUtils.sparsity(this.w));
                logger.debug("U row sparcity: " + MatlibMatrixUtils.sparsity(this.u));
                logger.debug("Total U magnitude: " + totalU);
                logger.debug("Total W magnitude: " + totalW);
                logger.debug("Total Bias: " + totalBias);
            }
            if (tolerance.doubleValue() < 0.0d || ratio < tolerance.doubleValue() || iter >= maxIter.intValue()) {
                break;
            }
        }
        logger.debug("tolerance reached after iteration: " + iter);
        logger.debug("W row sparcity: " + MatlibMatrixUtils.sparsity(this.w));
        logger.debug("U row sparcity: " + MatlibMatrixUtils.sparsity(this.u));
        logger.debug("Total U magnitude: " + totalU);
        logger.debug("Total W magnitude: " + totalW);
        logger.debug("Total Bias: " + totalBias);
    }

    /**
     * Returns {@code bias - eta * gradient}. The gradient matrix is scaled in place.
     *
     * @param matrix loss gradient w.r.t. the bias (mutated)
     * @param d step size
     * @return a new bias matrix
     */
    protected Matrix updateBias(Matrix matrix, double d) {
        return MatlibMatrixUtils.minus(this.bias, MatlibMatrixUtils.scaleInplace(matrix, d));
    }

    /**
     * Performs one proximal-gradient step on w with the loss's current X.
     *
     * @param current current w
     * @param eta gradient step size
     * @param lambda regularisation strength passed to the proximal operator
     * @return the updated w
     */
    protected Matrix updateW(Matrix current, double eta, double lambda) {
        return proximalGradientStep(current, eta, lambda);
    }

    /**
     * Performs one proximal-gradient step on u with the loss's current X.
     *
     * @param current current u
     * @param eta gradient step size
     * @param lambda regularisation strength passed to the proximal operator
     * @return the updated u
     */
    protected Matrix updateU(Matrix current, double eta, double lambda) {
        return proximalGradientStep(current, eta, lambda);
    }

    /** Shared body of updateW/updateU: {@code prox(current - eta * grad(current), lambda)}. */
    private Matrix proximalGradientStep(Matrix current, double eta, double lambda) {
        final Matrix gradient = this.loss.gradient(current);
        MatlibMatrixUtils.scaleInplace(gradient, eta);
        return this.regul.prox(MatlibMatrixUtils.minus(current, gradient), lambda);
    }

    /** Regularisation strength for iteration {@code i}: {@code lambda / i}. */
    private double lambdat(int i, double d) {
        return d / i;
    }

    /**
     * Expands row 0 of {@code matrix} into a square sparse matrix.
     * <p>
     * The diagonal holds the values from row 0. Every off-diagonal entry is set to
     * {@link Double#NaN}; presumably the loss treats NaN as "no target", so verify this
     * against the loss implementation.
     *
     * @param matrix Y; only row 0 is read
     * @return a (columns x columns) sparse matrix
     */
    public static SparseMatrix expandY(Matrix matrix) {
        final int n = matrix.columnCount();
        final SparseMatrix expanded = SparseMatrix.sparse(n, n);
        for (int row = 0; row < n; row++) {
            for (int col = 0; col < n; col++) {
                expanded.put(row, col, row == col ? matrix.get(0, col) : Double.NaN);
            }
        }
        return expanded;
    }

    /** Bias step size for iteration {@code i}: {@code ETA0_BIAS / sqrt(i)}. */
    private double biasEtat(int i) {
        return ((Double) this.params.getTyped(BilinearLearnerParameters.ETA0_BIAS)).doubleValue() / Math.sqrt(i);
    }

    /**
     * Step size for iteration {@code i}: {@code eta0 / sqrt(ceil(i / ETASTEPS))}.
     * <p>
     * The division is done in floating point. Integer division would make {@code ceil}
     * a no-op and give 0 for {@code i < ETASTEPS}, i.e. an infinite step size.
     *
     * @throws IllegalArgumentException if ETASTEPS is less than 1
     */
    private double etat(int i, double d) {
        final int etaSteps = ((Integer) this.params.getTyped(BilinearLearnerParameters.ETASTEPS)).intValue();
        if (etaSteps < 1) {
            throw new IllegalArgumentException("ETASTEPS must be >= 1, got " + etaSteps);
        }
        return eta(d) / Math.sqrt(Math.ceil(i / (double) etaSteps));
    }

    /** Base step size; currently the identity. */
    private double eta(double d) {
        return d;
    }

    /** @return the parameters this learner was built with (not a copy) */
    public BilinearLearnerParameters getParams() {
        return this.params;
    }

    /** @return the current u (not a copy), or null before training/loading */
    public Matrix getU() {
        return this.u;
    }

    /** @return the current w (not a copy), or null before training/loading */
    public Matrix getW() {
        return this.w;
    }

    /** @return the bias matrix in bias mode, otherwise null */
    public Matrix getBias() {
        if (this.biasMode.booleanValue()) {
            return this.bias;
        }
        return null;
    }

    /**
     * Appends {@code i} rows to u, initialised by EXPANDEDUINITSTRAT. Does nothing if u is null.
     *
     * @param i number of rows to add
     */
    public void addU(int i) {
        if (this.u == null) {
            return;
        }
        this.u = MatlibMatrixUtils.vstack(new Matrix[]{this.u, getInitStrat(BilinearLearnerParameters.EXPANDEDUINITSTRAT, null, null).init(i, this.u.columnCount())});
    }

    /**
     * Appends {@code i} rows to w, initialised by EXPANDEDWINITSTRAT. Does nothing if w is null.
     *
     * @param i number of rows to add
     */
    public void addW(int i) {
        if (this.w == null) {
            return;
        }
        this.w = MatlibMatrixUtils.vstack(new Matrix[]{this.w, getInitStrat(BilinearLearnerParameters.EXPANDEDWINITSTRAT, null, null).init(i, this.w.columnCount())});
    }

    /**
     * Returns an independent copy of this learner (decompiled name of {@code clone}).
     * <p>
     * The params object is shared. The matrices are deep-copied, and the "no data seen"
     * flag is carried over, so a trained copy keeps using its learned u. The bias is
     * copied whenever present, so {@link #writeBinary} also works on copies made outside
     * bias mode.
     */
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public MatlibBilinearSparseOnlineLearner m16clone() {
        final MatlibBilinearSparseOnlineLearner copy = new MatlibBilinearSparseOnlineLearner(getParams());
        copy.u = this.u == null ? null : MatlibMatrixUtils.copy(this.u);
        copy.w = this.w == null ? null : MatlibMatrixUtils.copy(this.w);
        if (this.bias != null) {
            copy.bias = MatlibMatrixUtils.copy(this.bias);
        }
        copy.nodataseen = this.nodataseen;
        return copy;
    }

    /** Replaces u (stored by reference). */
    public void setU(Matrix matrix) {
        this.u = matrix;
    }

    /** Replaces w (stored by reference). */
    public void setW(Matrix matrix) {
        this.w = matrix;
    }

    /**
     * Reads the model written by {@link #writeBinary(DataOutput)}.
     * <p>
     * The format is three ints (rows of w, rows of u, columns of u), followed by w, u and
     * bias as column-major doubles.
     */
    public void readBinary(DataInput dataInput) throws IOException {
        final int wRows = dataInput.readInt();
        final int uRows = dataInput.readInt();
        final int cols = dataInput.readInt();
        this.w = readColumnMajor(dataInput, wRows, cols);
        this.u = readColumnMajor(dataInput, uRows, cols);
        // Bias was written column-major too; reading it row-major would transpose it.
        this.bias = readColumnMajor(dataInput, cols, cols);
    }

    /** Reads a rows x cols matrix stored column-major, keeping only non-zero entries. */
    private static SparseMatrix readColumnMajor(DataInput in, int rows, int cols) throws IOException {
        final SparseMatrix result = SparseMatrix.sparse(rows, cols);
        for (int col = 0; col < cols; col++) {
            for (int row = 0; row < rows; row++) {
                final double value = in.readDouble();
                if (value != 0.0d) {
                    result.put(row, col, value);
                }
            }
        }
        return result;
    }

    /** @return an empty header */
    public byte[] binaryHeader() {
        return new byte[0];
    }

    /**
     * Writes the model in the format read by {@link #readBinary(DataInput)}. A missing
     * bias is written as a zero matrix.
     */
    public void writeBinary(DataOutput dataOutput) throws IOException {
        final int cols = this.u.columnCount();
        dataOutput.writeInt(this.w.rowCount());
        dataOutput.writeInt(this.u.rowCount());
        dataOutput.writeInt(cols);
        writeColumnMajor(dataOutput, this.w);
        writeColumnMajor(dataOutput, this.u);
        if (this.bias != null) {
            writeColumnMajor(dataOutput, this.bias);
        } else {
            for (int k = 0; k < cols * cols; k++) {
                dataOutput.writeDouble(0.0d);
            }
        }
    }

    /** Writes every entry of {@code m} in column-major order. */
    private static void writeColumnMajor(DataOutput out, Matrix m) throws IOException {
        for (double value : m.asColumnMajorArray()) {
            out.writeDouble(value);
        }
    }

    /**
     * Computes {@code u^T X^T w} (plus the bias in bias mode) and wraps it in a
     * {@link DiagonalMatrix}.
     *
     * @param matrix X
     * @return the prediction
     */
    @Override // org.openimaj.ml.linear.learner.OnlineLearner
    public Matrix predict(Matrix matrix) {
        final Matrix projected = MatlibMatrixUtils.dotProduct(MatlibMatrixUtils.dotProduct(MatlibMatrixUtils.transpose(this.u), MatlibMatrixUtils.transpose(matrix)), this.w);
        if (this.biasMode.booleanValue()) {
            MatlibMatrixUtils.plusInplace(projected, this.bias);
        }
        return new DiagonalMatrix(projected);
    }
}
