package org.openimaj.ml.linear.learner.loss;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;

/* loaded from: input_file:org/openimaj/ml/linear/learner/loss/MatLossFunction.class */
/**
 * A {@link LossFunction} that applies a wrapped loss function to a weight matrix one
 * column at a time.
 * <p>
 * {@link #gradient(Matrix)} slices {@code Y} (and {@code bias}, when one is set) column by
 * column. It hands each slice to the wrapped function together with the matching column of
 * the weight matrix, and writes each per-column gradient into the corresponding column of a
 * sparse result matrix. {@link #eval(Matrix)} delegates directly to the wrapped function.
 * <p>
 * The X/Y/bias setters forward to both this instance and the wrapped function, so the
 * wrapped function's state is mutated by this class. Instances are not safe for concurrent
 * use.
 */
public class MatLossFunction extends LossFunction {
    /** The wrapped loss function that is evaluated per column. */
    private final LossFunction f;

    /** Factory used to allocate the gradient matrix returned by {@link #gradient(Matrix)}. */
    private final SparseMatrixFactoryMTJ spf = SparseMatrixFactoryMTJ.INSTANCE;

    /**
     * Creates a column-wise wrapper around the given loss function.
     *
     * @param lossFunction the loss function applied to each column
     */
    public MatLossFunction(LossFunction lossFunction) {
        this.f = lossFunction;
    }

    /**
     * Sets X on this instance and on the wrapped loss function.
     *
     * @param matrix the new X matrix
     */
    @Override // org.openimaj.ml.linear.learner.loss.LossFunction
    public void setX(Matrix matrix) {
        super.setX(matrix);
        this.f.setX(matrix);
    }

    /**
     * Sets Y on this instance and on the wrapped loss function.
     *
     * @param matrix the new Y matrix
     */
    @Override // org.openimaj.ml.linear.learner.loss.LossFunction
    public void setY(Matrix matrix) {
        super.setY(matrix);
        this.f.setY(matrix);
    }

    /**
     * Sets the bias on this instance and on the wrapped loss function.
     *
     * @param matrix the new bias matrix (may be {@code null})
     */
    @Override // org.openimaj.ml.linear.learner.loss.LossFunction
    public void setBias(Matrix matrix) {
        super.setBias(matrix);
        this.f.setBias(matrix);
    }

    /**
     * Computes the gradient column by column.
     * <p>
     * For each column {@code i} of {@code Y}, the wrapped function is given column
     * {@code i} of {@code Y} (and of {@code bias}, when set). Its gradient with respect to
     * column {@code i} of {@code matrix} becomes column {@code i} of the result. Afterwards
     * the wrapped function is reset to the full {@code Y} and {@code bias}, even if it throws.
     *
     * @param matrix the weight matrix; assumed to have at least as many columns as {@code Y}
     * @return a sparse matrix the same shape as {@code matrix} holding the per-column gradients
     */
    @Override // org.openimaj.ml.linear.learner.loss.LossFunction
    public Matrix gradient(Matrix matrix) {
        final SparseMatrix result = this.spf.createMatrix(matrix.getNumRows(), matrix.getNumColumns());
        final int lastRowY = this.Y.getNumRows() - 1;
        final int lastRowW = matrix.getNumRows() - 1;
        try {
            for (int i = 0; i < this.Y.getNumColumns(); i++) {
                this.f.setY(this.Y.getSubMatrix(0, lastRowY, i, i));
                if (this.bias != null) {
                    // NOTE(review): the bias is sliced with Y's row count; this assumes
                    // bias has the same number of rows as Y.
                    this.f.setBias(this.bias.getSubMatrix(0, lastRowY, i, i));
                }
                result.setSubMatrix(0, i, this.f.gradient(matrix.getSubMatrix(0, lastRowW, i, i)));
            }
        } finally {
            // The loop leaves the wrapped function holding only the last column of Y/bias.
            // Restore the full matrices so later calls (e.g. eval) operate on consistent state.
            this.f.setY(this.Y);
            if (this.bias != null) {
                this.f.setBias(this.bias);
            }
        }
        return result;
    }

    /**
     * Evaluates the wrapped loss function on the full weight matrix, after resetting its
     * bias to this instance's bias.
     *
     * @param matrix the weight matrix
     * @return the loss reported by the wrapped function
     */
    @Override // org.openimaj.ml.linear.learner.loss.LossFunction
    public double eval(Matrix matrix) {
        this.f.setBias(this.bias);
        return this.f.eval(matrix);
    }
}
