package org.openimaj.ml.linear.learner;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.log4j.Logger;
import org.openimaj.io.ReadWriteableBinary;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.learner.init.ContextAwareInitStrategy;
import org.openimaj.ml.linear.learner.init.InitStrategy;
import org.openimaj.ml.linear.learner.loss.LossFunction;
import org.openimaj.ml.linear.learner.loss.MatLossFunction;
import org.openimaj.ml.linear.learner.regul.Regulariser;

/* loaded from: input_file:org/openimaj/ml/linear/learner/BilinearSparseOnlineLearner.class */
/**
 * Online learner for a bilinear model. For an input {@code x} and parameter matrices {@code u} and
 * {@code w}, the model output is the diagonal of {@code u^T x^T w} (plus {@code bias} when bias mode
 * is enabled); see {@link #predict(Matrix)}.
 *
 * <p>Each call to {@link #process(Matrix, Matrix)} runs a biconvex alternating optimisation:
 * <ul>
 * <li>a proximal-gradient step on {@code u} with {@code w} held fixed;</li>
 * <li>a proximal-gradient step on {@code w} with the new {@code u} held fixed;</li>
 * <li>in bias mode, a gradient step on {@code bias}.</li>
 * </ul>
 * The loop stops when the mean relative change of the parameters drops below {@code BICONVEX_TOL}.
 * It stops after the first iteration when that tolerance is negative, and otherwise after
 * {@code BICONVEX_MAXITER} iterations. At least one iteration always runs.
 *
 * <p>Shapes, as allocated by {@code initParams}:
 * <ul>
 * <li>{@code w} is (rows of x) by (columns of y);</li>
 * <li>{@code u} is (columns of x) by (columns of y);</li>
 * <li>{@code bias} is square with side (columns of y).</li>
 * </ul>
 */
public class BilinearSparseOnlineLearner implements OnlineLearner<Matrix, Matrix>, ReadWriteableBinary {
    static Logger logger = Logger.getLogger(BilinearSparseOnlineLearner.class);

    /** Configuration read by {@link #reinitParams()} and consulted on every {@link #process} call. */
    protected BilinearLearnerParameters params;
    /** Weights indexed by the rows of the input; null until set by process, setW or readBinary. */
    protected Matrix w;
    /** Weights indexed by the columns of the input; null until set by process, setU or readBinary. */
    protected Matrix u;
    /** Factory used to allocate sparse matrices. */
    protected SparseMatrixFactoryMTJ smf;
    /** Loss being minimised; wrapped in a {@link MatLossFunction} by {@link #reinitParams()}. */
    protected LossFunction loss;
    /** Regulariser whose proximal operator is applied after each gradient step on u and w. */
    protected Regulariser regul;
    /** Base regularisation weight for w; divided by the iteration number on each iteration. */
    protected Double lambda_w;
    /** Base regularisation weight for u; divided by the iteration number on each iteration. */
    protected Double lambda_u;
    /** Whether a bias matrix is learned and added to predictions. */
    protected Boolean biasMode;
    /** Bias matrix (tasks x tasks); only updated and exposed in bias mode. */
    protected Matrix bias;
    /** Identity matrix used as the loss input for the bias gradient step (bias mode only). */
    protected Matrix diagX;
    /** Base learning rate for u. */
    protected Double eta0_u;
    /** Base learning rate for w. */
    protected Double eta0_w;
    /** When true, the updated u and w are copied into sparse matrices after every iteration. */
    private Boolean forceSparcity;

    /** Creates a learner using default {@link BilinearLearnerParameters}. */
    public BilinearSparseOnlineLearner() {
        this(new BilinearLearnerParameters());
    }

    /**
     * Creates a learner with the given parameters.
     *
     * @param bilinearLearnerParameters the configuration; read immediately by {@link #reinitParams()}
     */
    public BilinearSparseOnlineLearner(BilinearLearnerParameters bilinearLearnerParameters) {
        this.smf = SparseMatrixFactoryMTJ.INSTANCE;
        this.params = bilinearLearnerParameters;
        reinitParams();
    }

    /**
     * Re-reads settings from {@link #params}: the loss, regulariser, regularisation weights, learning
     * rates, bias mode and sparsity flag. The loss is wrapped in a {@link MatLossFunction}.
     * The parameter matrices (u, w, bias) are not touched.
     */
    public void reinitParams() {
        this.loss = (LossFunction) this.params.getTyped(BilinearLearnerParameters.LOSS);
        this.regul = (Regulariser) this.params.getTyped(BilinearLearnerParameters.REGUL);
        this.lambda_w = (Double) this.params.getTyped(BilinearLearnerParameters.LAMBDA_W);
        this.lambda_u = (Double) this.params.getTyped(BilinearLearnerParameters.LAMBDA_U);
        this.biasMode = (Boolean) this.params.getTyped(BilinearLearnerParameters.BIAS);
        this.eta0_u = (Double) this.params.getTyped(BilinearLearnerParameters.ETA0_U);
        this.eta0_w = (Double) this.params.getTyped(BilinearLearnerParameters.ETA0_W);
        this.forceSparcity = (Boolean) this.params.getTyped(BilinearLearnerParameters.FORCE_SPARCITY);
        this.loss = new MatLossFunction(this.loss);
    }

    /**
     * Allocates w, u and bias from their configured init strategies. In bias mode it also
     * allocates the identity used for the bias step.
     */
    private void initParams(Matrix x, Matrix y, int xrows, int xcols, int ntasks) {
        // Resolve W's strategy before U's, as the original did; context-aware strategies are told about this learner.
        final InitStrategy wInit = getInitStrat(BilinearLearnerParameters.WINITSTRAT, x, y);
        final InitStrategy uInit = getInitStrat(BilinearLearnerParameters.UINITSTRAT, x, y);
        this.w = wInit.init(xrows, ntasks);
        this.u = uInit.init(xcols, ntasks);
        this.bias = this.smf.createMatrix(ntasks, ntasks);
        if (this.biasMode.booleanValue()) {
            this.bias = getInitStrat(BilinearLearnerParameters.BIASINITSTRAT, x, y).init(ntasks, ntasks);
            this.diagX = this.smf.createIdentity(ntasks, ntasks);
        }
    }

    /**
     * Looks up the init strategy stored under {@code key}. If it is a
     * {@link ContextAwareInitStrategy}, it is first given this learner and the (x, y) context.
     */
    private InitStrategy getInitStrat(String key, Matrix x, Matrix y) {
        final InitStrategy strategy = (InitStrategy) this.params.getTyped(key);
        if (strategy instanceof ContextAwareInitStrategy) {
            final ContextAwareInitStrategy contextAware = (ContextAwareInitStrategy) strategy;
            contextAware.setLearner(this);
            contextAware.setContext(x, y);
        }
        return strategy;
    }

    /**
     * Performs one online update from a single example.
     *
     * <p>On the first call, u, w and bias are allocated. All parameters are then scaled by
     * {@code 1 - DAMPENING}, and the biconvex loop described on the class runs.
     *
     * @param x input matrix; its rows index the rows of w and its columns index the rows of u
     * @param y targets; its column count is the number of tasks, and only row 0 is read
     *     (see {@link #expandY(Matrix)})
     */
    @Override // org.openimaj.ml.linear.learner.OnlineLearner
    public void process(Matrix x, Matrix y) {
        final int xrows = x.getNumRows();
        final int xcols = x.getNumColumns();
        final int ntasks = y.getNumColumns();
        if (this.w == null) {
            initParams(x, y, xrows, xcols, ntasks);
        }
        if (this.biasMode.booleanValue()
                && (this.diagX == null || this.diagX.getNumRows() != ntasks)) {
            // diagX is normally built by initParams. It is missing or stale when w/u came from
            // setW/setU, readBinary or cloning, so rebuild it rather than hand the loss a null input.
            this.diagX = this.smf.createIdentity(ntasks, ntasks);
        }

        final double scale = 1.0d - ((Double) this.params.getTyped(BilinearLearnerParameters.DAMPENING)).doubleValue();
        logger.debug("... dampening w, u and bias by: " + scale);
        this.w.scaleEquals(scale);
        this.u.scaleEquals(scale);
        if (this.biasMode.booleanValue()) {
            this.bias.scaleEquals(scale);
        }

        final double tolerance = ((Double) this.params.getTyped(BilinearLearnerParameters.BICONVEX_TOL)).doubleValue();
        final int maxIter = ((Integer) this.params.getTyped(BilinearLearnerParameters.BICONVEX_MAXITER)).intValue();

        this.loss.setY(expandY(y));
        int iter = 0;
        do {
            if (this.biasMode.booleanValue()) {
                this.loss.setBias(this.bias);
            }
            iter++;
            final double etaU = etat(iter, this.eta0_u.doubleValue());
            final double etaW = etat(iter, this.eta0_w.doubleValue());
            final double lambdaU = lambdat(iter, this.lambda_u.doubleValue());
            final double lambdaW = lambdat(iter, this.lambda_w.doubleValue());

            // Step on u with w fixed.
            this.loss.setX(x.transpose().times(this.w).transpose());
            final Matrix newU = updateU(this.u, etaU, lambdaU);
            // Step on w with the freshly updated u fixed.
            this.loss.setX(newU.transpose().times(x.transpose()));
            final Matrix newW = updateW(this.w, etaW, lambdaW);

            // Mean relative change over all updated parameter blocks. Both the u and w terms are
            // summed; previously the u term was overwritten by the w term before averaging.
            double ratio = relativeChange(newU, this.u) + relativeChange(newW, this.w);
            if (this.biasMode.booleanValue()) {
                final Matrix prediction = newU.transpose().times(x.transpose()).times(newW).plus(this.bias);
                this.loss.setBias(null);
                this.loss.setX(this.diagX);
                final Matrix newBias = updateBias(this.loss.gradient(prediction), biasEtat(iter));
                ratio += relativeChange(newBias, this.bias);
                this.bias = newBias;
                ratio /= 3.0d;
            } else {
                ratio /= 2.0d;
            }

            if (this.forceSparcity.booleanValue()) {
                this.w = this.smf.copyMatrix(newW);
                this.u = this.smf.copyMatrix(newU);
            } else {
                this.w = newW;
                this.u = newU;
            }

            if (iter % 3 == 0) {
                logger.debug(String.format("Iter: %d. Last Ratio: %2.3f", Integer.valueOf(iter), Double.valueOf(ratio)));
            }
            if (tolerance < 0.0d || ratio < tolerance) {
                break;
            }
        } while (iter < maxIter);
        logger.debug("tolerance reached after iteration: " + iter);
    }

    /**
     * Returns the relative change between two iterates: {@code absSum(next - prev) / absSum(prev)}.
     * Returns 0 when {@code prev} sums to zero, which avoids a 0/0 NaN that would otherwise stop
     * the tolerance test from ever passing.
     */
    private static double relativeChange(Matrix next, Matrix prev) {
        final double total = CFMatrixUtils.absSum(prev);
        if (total == 0.0d) {
            return 0.0d;
        }
        return CFMatrixUtils.absSum(next.minus(prev)) / total;
    }

    /**
     * Returns {@code bias - eta * gradient}.
     *
     * @param matrix the bias gradient; scaled in place by {@code d}
     * @param d the bias learning rate
     * @return the new bias matrix
     */
    protected Matrix updateBias(Matrix matrix, double d) {
        return this.bias.minus(CFMatrixUtils.timesInplace(matrix, d));
    }

    /**
     * One proximal-gradient step on w.
     *
     * @param matrix the current w
     * @param d the learning rate
     * @param d2 the regularisation weight passed to the regulariser's prox
     * @return {@code prox(w - d * gradient, d2)}
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public Matrix updateW(Matrix matrix, double d, double d2) {
        Matrix gradient = this.loss.gradient(matrix);
        CFMatrixUtils.timesInplace(gradient, d);
        return this.regul.prox(matrix.minus(gradient), d2);
    }

    /**
     * One proximal-gradient step on u.
     *
     * @param matrix the current u
     * @param d the learning rate
     * @param d2 the regularisation weight passed to the regulariser's prox
     * @return {@code prox(u - d * gradient, d2)}
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public Matrix updateU(Matrix matrix, double d, double d2) {
        Matrix gradient = this.loss.gradient(matrix);
        CFMatrixUtils.timesInplace(gradient, d);
        return this.regul.prox(matrix.minus(gradient), d2);
    }

    /** Regularisation weight at iteration {@code iter}: {@code lambda0 / iter}. */
    private double lambdat(int iter, double lambda0) {
        return lambda0 / iter;
    }

    /**
     * Expands the first row of {@code matrix} into a square matrix. Row 0's values go on the
     * diagonal and every off-diagonal entry is set to NaN.
     * NOTE(review): NaN presumably marks "no target" for the loss; confirm in MatLossFunction.
     *
     * @param matrix targets; only row 0 is read and its column count sets the output size
     * @return an n x n sparse matrix, where n is the column count of {@code matrix}
     */
    public static SparseMatrix expandY(Matrix matrix) {
        int numColumns = matrix.getNumColumns();
        SparseMatrix createMatrix = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(numColumns, numColumns);
        for (int i = 0; i < numColumns; i++) {
            for (int i2 = 0; i2 < numColumns; i2++) {
                if (i2 == i) {
                    createMatrix.setElement(i, i2, matrix.getElement(0, i2));
                } else {
                    createMatrix.setElement(i, i2, Double.NaN);
                }
            }
        }
        return createMatrix;
    }

    /** Bias learning rate at iteration {@code iter}: {@code ETA0_BIAS / sqrt(iter)}. */
    private double biasEtat(int iter) {
        return ((Double) this.params.getTyped(BilinearLearnerParameters.ETA0_BIAS)).doubleValue() / Math.sqrt(iter);
    }

    /**
     * Learning rate at iteration {@code iter}: {@code eta0 / sqrt(ceil(iter / ETASTEPS))}.
     *
     * @throws IllegalStateException if ETASTEPS is not positive
     */
    private double etat(int iter, double eta0) {
        final int etaSteps = ((Integer) this.params.getTyped(BilinearLearnerParameters.ETASTEPS)).intValue();
        if (etaSteps <= 0) {
            throw new IllegalStateException("ETASTEPS must be positive but was " + etaSteps);
        }
        // Floating-point division is required. With int division, ceil is a no-op, and for
        // iter < etaSteps the denominator is sqrt(0), giving an infinite learning rate.
        return eta(eta0) / Math.sqrt(Math.ceil(iter / (double) etaSteps));
    }

    /** Hook for the base learning rate; currently the identity. */
    private double eta(double d) {
        return d;
    }

    /** @return the parameters this learner was constructed with */
    public BilinearLearnerParameters getParams() {
        return this.params;
    }

    /** @return the current u matrix (null if not yet initialised) */
    public Matrix getU() {
        return this.u;
    }

    /** @return the current w matrix (null if not yet initialised) */
    public Matrix getW() {
        return this.w;
    }

    /** @return the bias matrix in bias mode, otherwise null */
    public Matrix getBias() {
        if (this.biasMode.booleanValue()) {
            return this.bias;
        }
        return null;
    }

    /**
     * Appends {@code i} rows to u, initialised by the EXPANDEDUINITSTRAT strategy. Does nothing if
     * u is not yet initialised.
     */
    public void addU(int i) {
        if (this.u == null) {
            return;
        }
        this.u = CFMatrixUtils.vstack(new Matrix[]{this.u, getInitStrat(BilinearLearnerParameters.EXPANDEDUINITSTRAT, null, null).init(i, this.u.getNumColumns())});
    }

    /**
     * Appends {@code i} rows to w, initialised by the EXPANDEDWINITSTRAT strategy. Does nothing if
     * w is not yet initialised.
     */
    public void addW(int i) {
        if (this.w == null) {
            return;
        }
        this.w = CFMatrixUtils.vstack(new Matrix[]{this.w, getInitStrat(BilinearLearnerParameters.EXPANDEDWINITSTRAT, null, null).init(i, this.w.getNumColumns())});
    }

    /**
     * Creates a copy of this learner. The copy shares this learner's parameter object and gets
     * clones of u, w, bias (in bias mode) and diagX. Matrices that are not yet initialised stay
     * null instead of causing a NullPointerException.
     */
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public BilinearSparseOnlineLearner m13clone() {
        final BilinearSparseOnlineLearner copy = new BilinearSparseOnlineLearner(getParams());
        if (this.u != null) {
            copy.u = this.u.clone();
        }
        if (this.w != null) {
            copy.w = this.w.clone();
        }
        if (this.biasMode.booleanValue() && this.bias != null) {
            copy.bias = this.bias.clone();
        }
        if (this.diagX != null) {
            copy.diagX = this.diagX.clone();
        }
        return copy;
    }

    /** Replaces u. */
    public void setU(Matrix matrix) {
        this.u = matrix;
    }

    /** Replaces w. */
    public void setW(Matrix matrix) {
        this.w = matrix;
    }

    /**
     * Reads the layout produced by {@link #writeBinary(DataOutput)}: three ints (rows of w, rows of
     * u, number of tasks), then the entries of w, u and bias. Zero entries are skipped so the sparse
     * matrices stay sparse.
     */
    public void readBinary(DataInput dataInput) throws IOException {
        final int wrows = dataInput.readInt();
        final int urows = dataInput.readInt();
        final int ntasks = dataInput.readInt();
        this.w = readColumnMajor(dataInput, wrows, ntasks);
        this.u = readColumnMajor(dataInput, urows, ntasks);
        // NOTE(review): bias is read row by row, while w and u are read column by column. If
        // CFMatrixUtils.getData emits column-major data (which w and u need in order to
        // round-trip), a non-symmetric bias comes back transposed. TODO confirm.
        this.bias = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(ntasks, ntasks);
        for (int r = 0; r < ntasks; r++) {
            for (int c = 0; c < ntasks; c++) {
                final double value = dataInput.readDouble();
                if (value != 0.0d) {
                    this.bias.setElement(r, c, value);
                }
            }
        }
    }

    /** Reads {@code rows * cols} doubles column by column into a new sparse matrix, skipping zeros. */
    private static Matrix readColumnMajor(DataInput in, int rows, int cols) throws IOException {
        final Matrix m = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(rows, cols);
        for (int c = 0; c < cols; c++) {
            for (int r = 0; r < rows; r++) {
                final double value = in.readDouble();
                if (value != 0.0d) {
                    m.setElement(r, c, value);
                }
            }
        }
        return m;
    }

    /** @return an empty binary header */
    public byte[] binaryHeader() {
        return new byte[0];
    }

    /**
     * Writes three ints (rows of w, rows of u, columns of u), then the entries of w, u and bias in
     * the order returned by {@code CFMatrixUtils.getData}.
     */
    public void writeBinary(DataOutput dataOutput) throws IOException {
        dataOutput.writeInt(this.w.getNumRows());
        dataOutput.writeInt(this.u.getNumRows());
        dataOutput.writeInt(this.u.getNumColumns());
        for (double d : CFMatrixUtils.getData(this.w)) {
            dataOutput.writeDouble(d);
        }
        for (double d2 : CFMatrixUtils.getData(this.u)) {
            dataOutput.writeDouble(d2);
        }
        for (double d3 : CFMatrixUtils.getData(this.bias)) {
            dataOutput.writeDouble(d3);
        }
    }

    /**
     * Predicts the per-task outputs for {@code matrix}.
     *
     * @param matrix input shaped like the {@code x} passed to {@link #process(Matrix, Matrix)}
     * @return a single-row matrix holding {@code diag(u^T x^T w [+ bias])}
     */
    @Override // org.openimaj.ml.linear.learner.OnlineLearner
    public Matrix predict(Matrix matrix) {
        Matrix times = this.u.transpose().times(matrix.transpose()).times(this.w);
        if (this.biasMode.booleanValue()) {
            times.plusEquals(this.bias);
        }
        Vector diag = CFMatrixUtils.diag(times);
        // Allocate a 1 x n matrix; row 0 is fully overwritten by the diagonal below.
        Matrix createIdentity = SparseMatrixFactoryMTJ.INSTANCE.createIdentity(1, diag.getDimensionality());
        createIdentity.setRow(0, diag);
        return createIdentity;
    }
}
