package org.openimaj.math.matrix;

import java.util.Arrays;

import ch.akuhn.matrix.Matrix;
import ch.akuhn.matrix.Vector;

/* loaded from: input_file:org/openimaj/math/matrix/DiagonalMatrix.class */
public class DiagonalMatrix extends Matrix {
    private double[] vals;

    public DiagonalMatrix(int i) {
        this.vals = new double[i];
    }

    public DiagonalMatrix(double[][] dArr) {
        this(Math.min(dArr.length, dArr[0].length));
        for (int i = 0; i < this.vals.length; i++) {
            this.vals[i] = dArr[i][i];
        }
    }

    public DiagonalMatrix(Matrix matrix) {
        this(Math.min(matrix.rowCount(), matrix.columnCount()));
        for (int i = 0; i < this.vals.length; i++) {
            this.vals[i] = matrix.get(i, i);
        }
    }

    public Vector mult(Vector vector) {
        double[] dArr = new double[columnCount()];
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = this.vals[i] * vector.get(i);
        }
        return Vector.wrap(dArr);
    }

    public Vector transposeMultiply(Vector vector) {
        return mult(vector);
    }

    public Vector transposeNonTransposeMultiply(Vector vector) {
        double[] dArr = new double[columnCount()];
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = this.vals[i] * this.vals[i] * vector.get(i);
        }
        return Vector.wrap(dArr);
    }

    public int columnCount() {
        return this.vals.length;
    }

    public double get(int i, int i2) {
        if (i != i2) {
            return 0.0d;
        }
        return this.vals[i];
    }

    public double put(int i, int i2, double d) {
        if (i != i2) {
            return 0.0d;
        }
        this.vals[i] = d;
        return d;
    }

    public int rowCount() {
        return this.vals.length;
    }

    public int used() {
        return this.vals.length;
    }

    public static DiagonalMatrix zeros(int i) {
        return fill(i, 0.0d);
    }

    public static DiagonalMatrix ones(int i) {
        return fill(i, 1.0d);
    }

    public static DiagonalMatrix fill(int i, double d) {
        DiagonalMatrix diagonalMatrix = new DiagonalMatrix(i);
        for (int i2 = 0; i2 < i; i2++) {
            diagonalMatrix.vals[i2] = d;
        }
        return diagonalMatrix;
    }

    public double[] getVals() {
        return this.vals;
    }

    public Matrix newInstance(int i, int i2) {
        return new DiagonalMatrix(i);
    }
}
