package org.openimaj.ml.linear.learner.perceptron;

import ch.akuhn.matrix.DenseMatrix;
import ch.akuhn.matrix.DenseVector;
import ch.akuhn.matrix.Matrix;
import ch.akuhn.matrix.Vector;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.List;
import org.openimaj.citation.annotation.Reference;
import org.openimaj.citation.annotation.ReferenceType;
import org.openimaj.citation.annotation.References;
import org.openimaj.math.matrix.MatlibMatrixUtils;
import org.openimaj.ml.linear.kernel.Kernel;
import org.openimaj.ml.linear.learner.OnlineLearner;
import org.openimaj.util.pair.IndependentPair;

/**
 * On-line Independent Support Vector Machine (OISVM) for binary classification, after
 * Orabona et al. (2010).
 * <p>
 * The learner keeps a basis of "support" vectors. A new sample joins the basis only when the
 * squared residual of projecting its feature-space image onto the span of the current basis
 * exceeds {@code eta}. The decision function is {@code f(x) = sum_j beta_j k(x, support_j)}.
 * Whenever a sample violates the margin ({@code y f(x) < 1}), {@code beta} is refitted with
 * Newton steps on
 * {@code 0.5 * beta' K_BB beta + C/2 * sum_t max(0, 1 - y_t f(x_t))^2},
 * where the sum runs over every sample processed so far.
 * <p>
 * All processed samples are retained (by reference, not copied), because adding a basis
 * vector requires its kernel values against every earlier sample.
 */
@References(references = {@Reference(type = ReferenceType.Article, author = {"Francesco Orabona", "Claudio Castellini", "Barbara Caputo", "Luo Jie", "Giulio Sandini"}, title = "On-line independent support vector machines", year = "2010", journal = "Pattern Recognition", pages = {"1402", "1412"}, number = "4", volume = "43")})
public class OISVM implements OnlineLearner<double[], PerceptronClass> {
    private static final int DEFAULT_NEWTON_ITER = 20;

    /** Kernel used for every inner product in feature space. */
    private final Kernel<double[]> kernel;
    /** Novelty threshold on the projection residual; larger values give sparser models. */
    private final double eta;
    /** Expansion coefficients of the decision function, one per support vector. */
    private double[] beta = new double[0];

    /** Support (basis) vectors, in the order they were added. */
    protected List<double[]> supports = new ArrayList<double[]>();
    /** For each support vector, its position in the sequence of all processed samples. */
    protected TIntArrayList supportIndex = new TIntArrayList();

    /** Maximum number of Newton iterations run after a margin violation. */
    int newtonMaxIter = DEFAULT_NEWTON_ITER;
    /** Weight of the loss term relative to the regulariser. */
    double C = 1.0d;

    /** Every processed sample (support vectors included), in arrival order. */
    private final List<double[]> samples = new ArrayList<double[]>();
    /** Label value ({@code PerceptronClass.v()}) of every processed sample, parallel to samples. */
    private final TDoubleArrayList labels = new TDoubleArrayList();
    /** {@code kbd.get(j).get(t)} is the kernel value between sample {@code t} and support {@code j}. */
    private final List<TDoubleArrayList> kbd = new ArrayList<TDoubleArrayList>();
    /** Inverse of the support-vector kernel matrix K_BB, maintained incrementally. */
    private double[][] kbbInv = new double[0][0];

    /**
     * @param kernel the kernel function
     * @param eta    novelty threshold: a sample whose projection residual exceeds it becomes a
     *               support vector
     */
    public OISVM(Kernel<double[]> kernel, double eta) {
        this.kernel = kernel;
        this.eta = eta;
    }

    /**
     * Processes one labelled sample. The sample is recorded and added to the basis if it is
     * sufficiently novel. The coefficients are refitted if the sample violates the margin.
     *
     * @param x the sample
     * @param y its class
     */
    @Override // org.openimaj.ml.linear.learner.OnlineLearner
    public void process(double[] x, PerceptronClass y) {
        final double kxx = kernelValue(x, x);
        final double[] kt = calculatekt(x);
        // The Newton update relies on y * y == 1; assumes v() returns +1 or -1.
        final double label = y.v();

        // Record the sample together with its kernel column against the current basis.
        final int t = samples.size();
        samples.add(x);
        labels.add(label);
        for (int j = 0; j < kt.length; j++) {
            kbd.get(j).add(kt[j]);
        }

        // The best approximation of phi(x) in the span of the basis has coefficients
        // d = K_BB^-1 k_t, and delta is the squared norm of the residual. The *unlabelled*
        // kernel vector must be used here: the inverse update below depends on the sign of d.
        final double[] d = multiply(kbbInv, kt);
        final double delta = kxx - dot(kt, d);
        if (delta > eta) {
            addSupport(x, t, kxx, d, delta);
        }

        // Only refit when the sample violates the margin. A freshly added support has a zero
        // coefficient, so the output over the previous basis equals the full output.
        if (label * dot(kt, beta) >= 1.0d) {
            return;
        }
        TIntArrayList active = new TIntArrayList();
        for (int iter = 0; iter < newtonMaxIter; iter++) {
            final TIntArrayList next = newtonStep(active);
            if (next.equals(active)) {
                break; // the active set is stable, so the Newton iteration has converged
            }
            active = next;
        }
    }

    /**
     * Adds sample {@code t} to the basis. This grows K_BB^-1 by the block-inverse formula,
     * appends the new support's kernel row over all samples, and extends {@code beta} with a
     * zero coefficient.
     */
    private void addSupport(double[] x, int t, double kxx, double[] d, double delta) {
        final int n = supports.size();
        supports.add(x);
        supportIndex.add(t);

        // Kernel of the new support with every processed sample; x itself is samples[t].
        final TDoubleArrayList row = new TDoubleArrayList(t + 1);
        for (int s = 0; s < t; s++) {
            row.add(kernelValue(samples.get(s), x));
        }
        row.add(kxx);
        kbd.add(row);

        // [A k; k' kxx]^-1 = [A^-1 + d d'/delta, -d/delta; -d'/delta, 1/delta], d = A^-1 k.
        final double[][] inv = new double[n + 1][n + 1];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                inv[i][j] = kbbInv[i][j] + d[i] * d[j] / delta;
            }
            inv[i][n] = -d[i] / delta;
            inv[n][i] = -d[i] / delta;
        }
        inv[n][n] = 1.0d / delta;
        kbbInv = inv;

        final double[] grown = new double[n + 1];
        System.arraycopy(beta, 0, grown, 0, n);
        beta = grown;
    }

    /**
     * Performs one Newton step on the primal objective.
     * <p>
     * The active set I is the set of samples with {@code y_t f(x_t) < 1}. If I equals
     * {@code previous}, the step has converged and {@code beta} is left untouched. Otherwise
     * {@code beta} is set to the exact minimiser for that active set:
     * {@code (K_BB + C K_I K_I')^-1 C K_I y_I}.
     *
     * @param previous the active set of the previous step
     * @return the active set computed from {@code beta} before this step's update
     */
    private TIntArrayList newtonStep(TIntArrayList previous) {
        final int n = supports.size();
        final TIntArrayList active = new TIntArrayList();
        for (int s = 0; s < samples.size(); s++) {
            if (labels.get(s) * dot(kernelColumn(s), beta) < 1.0d) {
                active.add(s);
            }
        }
        if (active.equals(previous)) {
            return active;
        }

        // Hessian H = K_BB + C * sum_{s in I} k_s k_s'; right-hand side C * sum_{s in I} y_s k_s.
        final double[][] hessian = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                hessian[i][j] = kbd.get(i).get(supportIndex.get(j));
            }
        }
        final double[] rhs = new double[n];
        for (int a = 0; a < active.size(); a++) {
            final int s = active.get(a);
            final double ys = labels.get(s);
            final double[] ks = kernelColumn(s);
            for (int i = 0; i < n; i++) {
                rhs[i] += C * ys * ks[i];
                for (int j = 0; j < n; j++) {
                    hessian[i][j] += C * ks[i] * ks[j];
                }
            }
        }
        beta = solve(hessian, rhs);
        return active;
    }

    /**
     * Predicts the class of {@code x} from the sign of the decision function. Before any
     * support vector exists, the decision value is 0.
     */
    @Override // org.openimaj.ml.linear.learner.OnlineLearner
    public PerceptronClass predict(double[] x) {
        return PerceptronClass.fromSign(Math.signum(dot(calculatekt(x), beta)));
    }

    /** Kernel values between {@code x} and every current support vector. */
    private double[] calculatekt(double[] x) {
        final double[] kt = new double[supports.size()];
        for (int j = 0; j < kt.length; j++) {
            kt[j] = kernelValue(x, supports.get(j));
        }
        return kt;
    }

    /** Kernel values between processed sample {@code s} and every current support vector. */
    private double[] kernelColumn(int s) {
        final double[] ks = new double[kbd.size()];
        for (int j = 0; j < ks.length; j++) {
            ks[j] = kbd.get(j).get(s);
        }
        return ks;
    }

    /** Evaluates the kernel on the pair {@code (a, b)}. */
    private double kernelValue(double[] a, double[] b) {
        return ((Double) this.kernel.apply(IndependentPair.pair(a, b))).doubleValue();
    }

    /** Dot product over the first {@code a.length} entries; requires {@code b.length >= a.length}. */
    private static double dot(double[] a, double[] b) {
        double sum = 0.0d;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /** Multiplies the square matrix {@code m} by the vector {@code v}. */
    private static double[] multiply(double[][] m, double[] v) {
        final double[] out = new double[m.length];
        for (int i = 0; i < m.length; i++) {
            out[i] = dot(m[i], v);
        }
        return out;
    }

    /**
     * Solves {@code a * x = b} by Gaussian elimination with partial pivoting. Both arguments
     * are overwritten.
     *
     * @throws IllegalStateException if a zero pivot is encountered (the system is singular)
     */
    private static double[] solve(double[][] a, double[] b) {
        final int n = b.length;
        for (int col = 0; col < n; col++) {
            int pivot = col;
            for (int r = col + 1; r < n; r++) {
                if (Math.abs(a[r][col]) > Math.abs(a[pivot][col])) {
                    pivot = r;
                }
            }
            if (a[pivot][col] == 0.0d) {
                throw new IllegalStateException("Newton system is singular");
            }
            final double[] rowTmp = a[col];
            a[col] = a[pivot];
            a[pivot] = rowTmp;
            final double bTmp = b[col];
            b[col] = b[pivot];
            b[pivot] = bTmp;
            for (int r = col + 1; r < n; r++) {
                final double factor = a[r][col] / a[col][col];
                if (factor != 0.0d) {
                    for (int c = col; c < n; c++) {
                        a[r][c] -= factor * a[col][c];
                    }
                    b[r] -= factor * b[col];
                }
            }
        }
        final double[] x = new double[n];
        for (int r = n - 1; r >= 0; r--) {
            double sum = b[r];
            for (int c = r + 1; c < n; c++) {
                sum -= a[r][c] * x[c];
            }
            x[r] = sum / a[r][r];
        }
        return x;
    }
}
