package org.openimaj.ml.linear.learner.loss;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.openimaj.math.matrix.CFMatrixUtils;

/* loaded from: input_file:org/openimaj/ml/linear/learner/loss/MatSquareLossFunction.class */
/**
 * Squared-error loss for matrix-valued weights.
 * <p>
 * Predictions are formed as {@code X·W}, plus {@code bias} when it is set, and compared with
 * {@code Y}. Only the diagonal entries {@code (i, i)} of the residual are used, for {@code i}
 * ranging over the residual's columns. Presumably entry {@code (i, i)} is the prediction for
 * task {@code i} — TODO confirm against the code that populates {@code X} and {@code Y}.
 */
public class MatSquareLossFunction extends LossFunction {
    Logger logger = LogManager.getLogger(MatSquareLossFunction.class);
    // NOTE(review): not referenced anywhere in this class.
    private SparseMatrixFactoryMTJ spf = SparseMatrixFactoryMTJ.INSTANCE;

    /**
     * Computes the gradient of the loss with respect to {@code matrix}.
     * <p>
     * The residual is {@code R = X·matrix (+ bias) - Y}. The result starts as a copy of
     * {@code matrix}. For every column {@code i} of {@code R}, column {@code i} of the result
     * is replaced by row {@code i} of {@code X} scaled by {@code R(i, i)}.
     * <p>
     * NOTE(review): if {@code fastdot} is the ordinary matrix product, this equals the gradient
     * of {@code 0.5·Σ R(i,i)²}, i.e. half the derivative of the value returned by
     * {@link #eval(Matrix)}. Confirm whether that factor is intended.
     *
     * @param matrix the current weights; not modified
     * @return a new matrix holding the gradient
     * @throws IllegalStateException if {@code X}, the product {@code X·matrix}, or the residual
     *     contains an infinite value (as reported by {@code CFMatrixUtils.containsInfinity})
     * @throws IllegalArgumentException if {@code matrix} contains an infinite value
     */
    @Override
    public Matrix gradient(Matrix matrix) {
        if (CFMatrixUtils.containsInfinity(this.X)) {
            throw new IllegalStateException("X contains an infinite value");
        }
        if (CFMatrixUtils.containsInfinity(matrix)) {
            throw new IllegalArgumentException("W contains an infinite value");
        }
        final Matrix residual = CFMatrixUtils.fastdot(this.X, matrix);
        if (CFMatrixUtils.containsInfinity(residual)) {
            throw new IllegalStateException("X*W contains an infinite value");
        }
        if (this.bias != null) {
            residual.plusEquals(this.bias);
        }
        CFMatrixUtils.fastminusEquals(residual, this.Y);
        if (CFMatrixUtils.containsInfinity(residual)) {
            throw new IllegalStateException("residual X*W + bias - Y contains an infinite value");
        }
        // Copy only after validation so no work is wasted on the failure paths.
        final Matrix grad = matrix.clone();
        for (int i = 0; i < residual.getNumColumns(); i++) {
            // Column i: row i of X weighted by the diagonal residual entry (i, i).
            CFMatrixUtils.fastsetcol(grad, i,
                    this.X.getRow(i).scale(residual.getElement(i, i)).clone());
        }
        return grad;
    }

    /**
     * Evaluates the squared loss.
     * <p>
     * Returns {@code Σ_i R(i, i)²} over the columns {@code i} of the residual
     * {@code R = P (+ bias) - Y}. Here {@code P = X·matrix}, or a copy of {@code X} itself when
     * {@code matrix} is {@code null}. When debug logging is enabled, one line per term is logged.
     *
     * @param matrix the weights, or {@code null} to treat {@code X} as the prediction
     * @return the sum of squared diagonal residuals
     */
    @Override
    public double eval(Matrix matrix) {
        final Matrix prediction =
                matrix == null ? this.X.clone() : CFMatrixUtils.fastdot(this.X, matrix);
        // These snapshots feed only the debug trace. Skip the copies (and the per-term
        // formatting below) when that trace would be discarded anyway.
        final boolean debug = this.logger.isDebugEnabled();
        final Matrix noBias = debug ? prediction.clone() : null;
        if (this.bias != null) {
            prediction.plusEquals(this.bias);
        }
        final Matrix withBias = debug ? prediction.clone() : null;
        prediction.minusEquals(this.Y);
        final Matrix residual = prediction;
        double total = 0.0d;
        for (int i = 0; i < residual.getNumColumns(); i++) {
            final double err = residual.getElement(i, i);
            total += err * err;
            if (debug) {
                this.logger.debug(String.format(
                        "yr=%d,y=%3.2f,v=%3.2f,v(no bias)=%2.5f,error=%2.5f,serror=%2.5f",
                        i, this.Y.getElement(i, i), withBias.getElement(i, i),
                        noBias.getElement(i, i), err, err * err));
            }
        }
        return total;
    }

    /**
     * Backtracking line-search acceptance test.
     * <p>
     * Returns {@code true} when
     * {@code eval(matrix3) <= eval(matrix) + sum(matrix2ᵀ·D) + 0.5·d·‖D‖_F},
     * where {@code D = matrix3 - matrix} and {@code sum} adds every entry of the product.
     * <p>
     * NOTE(review): the textbook sufficient-decrease condition uses the inner product
     * {@code trace(matrix2ᵀ·D)} and the squared Frobenius norm. This implementation sums all
     * entries of {@code matrix2ᵀ·D} and uses the unsquared norm. It is left unchanged to
     * preserve optimizer behaviour; confirm the intent.
     *
     * @param matrix the current weights
     * @param matrix2 presumably the gradient at {@code matrix} — TODO confirm with caller
     * @param matrix3 the candidate weights being tested
     * @param d the step-dependent scale factor applied to the norm term
     * @return whether the candidate is accepted
     */
    @Override
    public boolean test_backtrack(Matrix matrix, Matrix matrix2, Matrix matrix3, double d) {
        final Matrix step = matrix3.minus(matrix);
        // Evaluation order matches the original expression (candidate first, then current).
        final double candidateLoss = eval(matrix3);
        final double currentLoss = eval(matrix);
        final double linearTerm = CFMatrixUtils.sum(CFMatrixUtils.fastdot(matrix2.transpose(), step));
        final double normTerm = 0.5d * d * step.normFrobenius();
        return candidateLoss <= currentLoss + linearTerm + normTerm;
    }

    /**
     * Indicates that this loss operates on matrix-valued weights.
     *
     * @return always {@code true}
     */
    @Override
    public boolean isMatrixLoss() {
        return true;
    }
}
