package org.openimaj.ml.linear.learner.loss;

import gov.sandia.cognition.math.matrix.Matrix;
import org.apache.log4j.Logger;

/* loaded from: input_file:org/openimaj/ml/linear/learner/loss/SquareMissingLossFunction.class */
/**
 * Squared-error loss that ignores missing targets.
 * <p>
 * Any entry of the target matrix {@code Y} that is {@link Double#NaN} is treated as missing and
 * contributes nothing to either the loss value ({@link #eval}) or the gradient ({@link #gradient}).
 * The design matrix {@code X}, the targets {@code Y} and the optional {@code bias} are fields
 * inherited from {@link LossFunction}; they are presumably populated by the learner before these
 * methods are called (NOTE(review): confirm against LossFunction).
 */
public class SquareMissingLossFunction extends LossFunction {
    Logger logger = Logger.getLogger(SquareMissingLossFunction.class);

    /**
     * Computes the gradient of the masked squared loss with respect to the weights.
     * <p>
     * The residual {@code X*W - Y} is formed and, if present, {@code bias} is added element-wise.
     * Every residual entry whose target is NaN is then set to zero so that missing targets do not
     * influence the update, and {@code X^T * residual} is returned. Note that the constant factor
     * 2 of the exact derivative of {@link #eval} is intentionally omitted.
     *
     * @param matrix the weight matrix {@code W}
     * @return {@code X^T * residual}, with residuals of missing targets masked to zero
     */
    @Override // org.openimaj.ml.linear.learner.loss.LossFunction
    public Matrix gradient(Matrix matrix) {
        Matrix residual = this.X.times(matrix).minus(this.Y);
        if (this.bias != null) {
            residual.plusEquals(this.bias);
        }
        // Mask missing targets in EVERY column, matching eval(). Checking only column 0 would let
        // a NaN in any other column flow through the product below and turn the corresponding
        // gradient column into NaN.
        final int rows = this.Y.getNumRows();
        final int cols = this.Y.getNumColumns();
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                if (Double.isNaN(this.Y.getElement(r, c))) {
                    residual.setElement(r, c, 0.0d);
                }
            }
        }
        return this.X.transpose().times(residual);
    }

    /**
     * Evaluates the squared loss over the non-missing targets.
     *
     * @param matrix the weight matrix {@code W}, or {@code null} to treat {@code X} itself as the
     *        predictions
     * @return the sum over all entries with a non-NaN target of
     *         {@code (Y - (prediction + bias))^2}, where the bias term is included only when
     *         {@code bias} is set
     */
    @Override // org.openimaj.ml.linear.learner.loss.LossFunction
    public double eval(Matrix matrix) {
        // Predictions before the bias is applied (kept separately for debug logging).
        final Matrix withoutBias = matrix == null ? this.X : this.X.times(matrix);
        // Always add the bias to a copy: when matrix == null, withoutBias IS this.X, and adding
        // the bias in place would silently corrupt the inherited design matrix.
        final Matrix predicted = withoutBias.clone();
        if (this.bias != null) {
            predicted.plusEquals(this.bias);
        }

        // Avoid formatting/boxing on every element when debug logging is off.
        final boolean debug = this.logger.isDebugEnabled();
        final int rows = this.Y.getNumRows();
        final int cols = this.Y.getNumColumns();
        double sum = 0.0d;
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                final double target = this.Y.getElement(r, c);
                if (Double.isNaN(target)) {
                    continue; // missing target: contributes nothing to the loss
                }
                final double value = predicted.getElement(r, c);
                final double delta = target - value;
                if (debug) {
                    this.logger.debug(String.format("yr=%d,y=%3.2f,v=%3.2f,v(no bias)=%2.5f,delta=%2.5f", Integer.valueOf(r), Double.valueOf(target), Double.valueOf(value), Double.valueOf(withoutBias.getElement(r, c)), Double.valueOf(delta)));
                }
                sum += delta * delta;
            }
        }
        return sum;
    }
}
