package org.openimaj.workinprogress.sgdsvm;

import gnu.trove.list.array.TDoubleArrayList;
import java.util.List;
import org.apache.commons.math.random.MersenneTwister;
import org.openimaj.feature.FloatFV;
import org.openimaj.feature.FloatFVComparison;
import org.openimaj.util.array.ArrayUtils;
import org.openimaj.util.array.SparseFloatArray;
import org.openimaj.util.array.SparseHashedFloatArray;

/* loaded from: input_file:org/openimaj/workinprogress/sgdsvm/SvmSgd.class */
/**
 * Linear SVM trained by plain stochastic gradient descent on sparse float features.
 * <p>
 * The weight vector is kept in scaled form: the effective weights are
 * {@code w / wDivisor}. The L2 shrink by {@code (1 - eta * lambda)} performed on every
 * step therefore only updates the scalar {@link #wDivisor}; {@link #renorm()} folds the
 * divisor back into {@link #w} once it exceeds a threshold.
 * <p>
 * No synchronization is performed; {@link #clone()} yields a copy with its own weight vector.
 */
public class SvmSgd implements Cloneable {
    /** Loss used for training and evaluation; the constructors set {@code LossFunctions.HingeLoss}. */
    Loss LOSS;
    /** Whether a bias term ({@link #wBias}) is learned. */
    boolean BIAS;
    /** Whether the bias term is also shrunk by the L2 regulariser in {@link #trainOne}. */
    boolean REGULARIZED_BIAS;
    /** L2 regularisation strength. */
    public double lambda;
    /** Initial learning rate; {@code train} asserts it is positive. */
    public double eta0;
    /** Scaled weights; the effective weight vector is {@code w / wDivisor}. */
    FloatFV w;
    /** Divisor applied to {@link #w} to obtain the effective weights. */
    double wDivisor;
    /** Bias term added to every decision value. */
    double wBias;
    /** Number of updates performed by {@code train}; drives the learning-rate decay. */
    double t;

    /**
     * Random-access view of labelled examples, letting the array- and list-based public
     * overloads share a single implementation.
     */
    private interface Examples {
        /** Returns the feature vector of example {@code i}. */
        SparseFloatArray x(int i);

        /** Returns the label of example {@code i}. */
        double y(int i);
    }

    /** Adapts parallel feature/label arrays to {@link Examples} without copying. */
    private static Examples examples(final SparseFloatArray[] xp, final double[] yp) {
        return new Examples() {
            @Override
            public SparseFloatArray x(int i) {
                return xp[i];
            }

            @Override
            public double y(int i) {
                return yp[i];
            }
        };
    }

    /** Adapts a feature list and a label list to {@link Examples} without copying. */
    private static Examples examples(final List<SparseFloatArray> xp, final TDoubleArrayList yp) {
        return new Examples() {
            @Override
            public SparseFloatArray x(int i) {
                return xp.get(i);
            }

            @Override
            public double y(int i) {
                return yp.get(i);
            }
        };
    }

    /**
     * Creates an untrained model with {@code eta0 = 0}; set {@link #eta0} or call
     * {@code determineEta0} before calling {@code train}.
     *
     * @param dim    length of the weight vector
     * @param lambda L2 regularisation strength
     */
    public SvmSgd(int dim, double lambda) {
        this(dim, lambda, 0.0);
    }

    /**
     * Creates an untrained model with zero weights, zero bias and an unregularised bias term.
     *
     * @param dim    length of the weight vector
     * @param lambda L2 regularisation strength
     * @param eta0   initial learning rate
     */
    public SvmSgd(int dim, double lambda, double eta0) {
        this.LOSS = LossFunctions.HingeLoss;
        this.BIAS = true;
        this.REGULARIZED_BIAS = false;
        this.lambda = lambda;
        this.eta0 = eta0;
        this.w = new FloatFV(dim);
        this.wDivisor = 1.0;
        this.wBias = 0.0;
        this.t = 0.0;
    }

    /** Inner product of a dense vector with a sparse vector, accumulated in double precision. */
    private double dot(FloatFV dense, SparseFloatArray sparse) {
        final float[] values = dense.values;
        double sum = 0.0;
        for (SparseFloatArray.Entry e : sparse.entries()) {
            // promote before multiplying so each product keeps double precision
            sum += (double) e.value * values[e.index];
        }
        return sum;
    }

    /** Inner product of two dense vectors. */
    private double dot(FloatFV a, FloatFV b) {
        return FloatFVComparison.INNER_PRODUCT.compare(a, b);
    }

    /** In-place {@code dense += scale * sparse}. */
    private void add(FloatFV dense, SparseFloatArray sparse, double scale) {
        final float[] values = dense.values;
        for (SparseFloatArray.Entry e : sparse.entries()) {
            values[e.index] = (float) (values[e.index] + e.value * scale);
        }
    }

    /** Folds {@link #wDivisor} into {@link #w} so that the divisor becomes 1 again. */
    public void renorm() {
        if (this.wDivisor != 1.0) {
            ArrayUtils.multiply(this.w.values, (float) (1.0 / this.wDivisor));
            this.wDivisor = 1.0;
        }
    }

    /**
     * Returns the squared L2 norm of the effective weights, plus {@code wBias^2} when the
     * bias is regularised.
     */
    public double wnorm() {
        double norm = dot(this.w, this.w) / this.wDivisor / this.wDivisor;
        if (this.REGULARIZED_BIAS) {
            norm += this.wBias * this.wBias;
        }
        return norm;
    }

    /**
     * Computes the decision value for one example and optionally accumulates its loss and
     * misclassification count.
     *
     * @param x     feature vector
     * @param y     label
     * @param ploss if non-null, {@code ploss[0]} is incremented by the loss
     * @param pnerr if non-null, {@code pnerr[0]} is incremented by 1 when {@code s * y <= 0}
     * @return the decision value {@code s = w.x / wDivisor + wBias}
     */
    public double testOne(SparseFloatArray x, double y, double[] ploss, double[] pnerr) {
        final double s = dot(this.w, x) / this.wDivisor + this.wBias;
        if (ploss != null) {
            ploss[0] += this.LOSS.loss(s, y);
        }
        if (pnerr != null) {
            pnerr[0] += (s * y <= 0.0) ? 1.0 : 0.0;
        }
        return s;
    }

    /**
     * Performs one SGD step on a single example.
     *
     * @param x   feature vector
     * @param y   label
     * @param eta learning rate for this step
     */
    public void trainOne(SparseFloatArray x, double y, double eta) {
        // decision value uses the weights from before this step
        final double s = dot(this.w, x) / this.wDivisor + this.wBias;
        // L2 shrink of the effective weights by (1 - eta * lambda), applied through the divisor
        this.wDivisor /= 1.0 - eta * this.lambda;
        if (this.wDivisor > 1e5) {
            // presumably keeps the stored float weights in a well-conditioned range
            renorm();
        }
        // NOTE(review): adding +eta*d descends only if Loss.dloss returns the negated
        // derivative of the loss - confirm against LossFunctions.
        final double d = this.LOSS.dloss(s, y);
        if (d != 0.0) {
            // scale by wDivisor so the change to the effective weights is eta * d * x
            add(this.w, x, eta * d * this.wDivisor);
        }
        if (this.BIAS) {
            // the bias uses a learning rate 100x smaller than the weights
            final double etab = eta * 0.01;
            if (this.REGULARIZED_BIAS) {
                this.wBias *= 1.0 - etab * this.lambda;
            }
            this.wBias += etab * d;
        }
    }

    /**
     * Returns a copy of this model whose weight vector is independent of this one.
     */
    @Override
    public SvmSgd clone() {
        try {
            final SvmSgd copy = (SvmSgd) super.clone();
            copy.w = copy.w.clone();
            return copy;
        } catch (CloneNotSupportedException e) {
            // cannot happen: this class implements Cloneable
            throw new RuntimeException(e);
        }
    }

    /**
     * Alias of {@link #clone()} kept for backward compatibility.
     *
     * @deprecated decompiler-generated name; use {@link #clone()} instead.
     */
    @Deprecated
    public SvmSgd m52clone() {
        return clone();
    }

    /**
     * Runs one SGD pass over examples {@code imin..imax} (inclusive) and prints the resulting
     * weight norm (and bias).
     */
    public void train(int imin, int imax, SparseFloatArray[] xp, double[] yp) {
        doTrain(imin, imax, examples(xp, yp));
    }

    /** List-based variant of {@link #train(int, int, SparseFloatArray[], double[])}. */
    public void train(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp) {
        doTrain(imin, imax, examples(xp, yp));
    }

    /** Shared implementation of the {@code train} overloads. */
    private void doTrain(int imin, int imax, Examples data) {
        System.out.println("Training on [" + imin + ", " + imax + "].");
        assert imin <= imax;
        assert this.eta0 > 0;
        for (int i = imin; i <= imax; i++) {
            // learning rate decays as eta0 / (1 + lambda * eta0 * t)
            final double eta = this.eta0 / (1.0 + this.lambda * this.eta0 * this.t);
            trainOne(data.x(i), data.y(i), eta);
            this.t += 1.0;
        }
        System.out.format("wNorm=%.6f", wnorm());
        if (this.BIAS) {
            System.out.format(" wBias=%.6f", this.wBias);
        }
        System.out.println();
    }

    /**
     * Evaluates the model on examples {@code imin..imax} (inclusive) and prints the mean loss,
     * the regularised cost and the misclassification percentage, each line prefixed by
     * {@code prefix}.
     */
    public void test(int imin, int imax, SparseFloatArray[] xp, double[] yp, String prefix) {
        doTest(imin, imax, examples(xp, yp), prefix);
    }

    /** List-based variant of {@link #test(int, int, SparseFloatArray[], double[], String)}. */
    public void test(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp, String prefix) {
        doTest(imin, imax, examples(xp, yp), prefix);
    }

    /** Shared implementation of the {@code test} overloads. */
    private void doTest(int imin, int imax, Examples data, String prefix) {
        System.out.println(prefix + "Testing on [" + imin + ", " + imax + "].");
        assert imin <= imax;
        final double[] nerr = { 0.0 };
        final double[] loss = { 0.0 };
        for (int i = imin; i <= imax; i++) {
            testOne(data.x(i), data.y(i), loss, nerr);
        }
        final int n = imax - imin + 1;
        nerr[0] = nerr[0] / n;
        loss[0] = loss[0] / n;
        final double cost = loss[0] + 0.5 * this.lambda * wnorm();
        System.out.println(prefix + "Loss=" + loss[0] + " Cost=" + cost + " Misclassification="
                + String.format("%2.4f", 100.0 * nerr[0]) + "%");
    }

    /**
     * Trains a clone of this model for one pass over {@code imin..imax} with a fixed learning
     * rate {@code eta}, then returns the clone's mean loss plus {@code 0.5 * lambda * wnorm()}
     * on the same examples. This model is not modified.
     */
    public double evaluateEta(int imin, int imax, SparseFloatArray[] xp, double[] yp, double eta) {
        return doEvaluateEta(imin, imax, examples(xp, yp), eta);
    }

    /** List-based variant of {@link #evaluateEta(int, int, SparseFloatArray[], double[], double)}. */
    public double evaluateEta(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp, double eta) {
        return doEvaluateEta(imin, imax, examples(xp, yp), eta);
    }

    /** Shared implementation of the {@code evaluateEta} overloads. */
    private double doEvaluateEta(int imin, int imax, Examples data, double eta) {
        final SvmSgd trial = clone();
        assert imin <= imax;
        for (int i = imin; i <= imax; i++) {
            trial.trainOne(data.x(i), data.y(i), eta);
        }
        final double[] loss = { 0.0 };
        for (int i = imin; i <= imax; i++) {
            trial.testOne(data.x(i), data.y(i), loss, null);
        }
        loss[0] = loss[0] / (imax - imin + 1);
        final double cost = loss[0] + 0.5 * this.lambda * trial.wnorm();
        System.out.println("Trying eta=" + eta + " yields cost " + cost);
        return cost;
    }

    /**
     * Picks {@link #eta0} by repeatedly halving or doubling a candidate learning rate
     * (starting from 1) while {@code evaluateEta} keeps improving.
     */
    public void determineEta0(int imin, int imax, SparseFloatArray[] xp, double[] yp) {
        doDetermineEta0(imin, imax, examples(xp, yp));
    }

    /** List-based variant of {@link #determineEta0(int, int, SparseFloatArray[], double[])}. */
    public void determineEta0(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp) {
        doDetermineEta0(imin, imax, examples(xp, yp));
    }

    /** Shared implementation of the {@code determineEta0} overloads. */
    private void doDetermineEta0(int imin, int imax, Examples data) {
        final double factor = 2.0;
        double loEta = 1.0;
        double loCost = doEvaluateEta(imin, imax, data, loEta);
        double hiEta = loEta * factor;
        double hiCost = doEvaluateEta(imin, imax, data, hiEta);
        if (loCost < hiCost) {
            // smaller rates look better: keep halving while the cost decreases
            while (loCost < hiCost) {
                hiEta = loEta;
                hiCost = loCost;
                loEta = hiEta / factor;
                loCost = doEvaluateEta(imin, imax, data, loEta);
            }
        } else if (hiCost < loCost) {
            // larger rates look better: keep doubling while the cost decreases
            while (hiCost < loCost) {
                loEta = hiEta;
                loCost = hiCost;
                hiEta = loEta * factor;
                hiCost = doEvaluateEta(imin, imax, data, hiEta);
            }
        }
        // NOTE(review): after the halving branch this keeps the smaller rate even though hiEta
        // had the lower (or equal) cost; this looks deliberately conservative - confirm.
        this.eta0 = loEta;
        System.out.println("# Using eta0=" + this.eta0 + "\n");
    }

    /**
     * Demo: trains on two 2-D Gaussian blobs, one centred at (-2, -2) and labelled -1, the
     * other centred at (+2, +2) and labelled +1. It then runs ten train/test passes over the
     * same data.
     */
    public static void main(String[] args) {
        final MersenneTwister rng = new MersenneTwister();
        final SparseFloatArray[] xp = new SparseFloatArray[10000];
        final double[] yp = new double[xp.length];
        for (int i = 0; i < xp.length; i++) {
            xp[i] = new SparseHashedFloatArray(2);
            if (i < xp.length / 2) {
                xp[i].set(0, (float) (rng.nextGaussian() - 2.0));
                xp[i].set(1, (float) (rng.nextGaussian() - 2.0));
                yp[i] = -1.0;
            } else {
                xp[i].set(0, (float) (rng.nextGaussian() + 2.0));
                xp[i].set(1, (float) (rng.nextGaussian() + 2.0));
                yp[i] = 1.0;
            }
            System.out.println(xp[i].values()[0] + " " + yp[i]);
        }

        final SvmSgd svm = new SvmSgd(2, 1e-5);
        svm.BIAS = true;
        svm.REGULARIZED_BIAS = false;
        svm.determineEta0(0, xp.length - 1, xp, yp);
        for (int epoch = 0; epoch < 10; epoch++) {
            System.out.println();
            svm.train(0, xp.length - 1, xp, yp);
            svm.test(0, xp.length - 1, xp, yp, "training ");
            System.out.println(svm.w);
            System.out.println(svm.wBias);
            System.out.println(svm.wDivisor);
        }
    }
}
