package org.openimaj.ml.classification;

import org.openimaj.util.function.Operation;
import org.openimaj.util.pair.ObjectFloatPair;
import org.openimaj.util.parallel.Parallel;

/* loaded from: input_file:org/openimaj/ml/classification/StumpClassifier.class */
/**
 * A decision stump: a classifier that inspects a single feature response,
 * compares it with a threshold and maps the outcome through a polarity.
 */
public class StumpClassifier {
    /** Index of the feature response this stump examines. */
    public int dimension;

    /** Value the examined response is compared against. */
    public float threshold;

    /**
     * Polarity of the stump. With {@code 1}, responses strictly greater than
     * {@link #threshold} classify as {@code true}; with {@code -1}, every other
     * response (including NaN) classifies as {@code true}. Any other value makes
     * both {@code classify} methods return {@code false}.
     */
    public int sign;

    /**
     * Exhaustive stump learner: scans every dimension of a data provider and
     * keeps the (dimension, threshold, sign) combination with the lowest
     * accumulated weight on the wrong side of the split.
     */
    public static class WeightedLearner {
        /**
         * Finds the stump with the smallest weighted error.
         *
         * <p>For each dimension the instances are walked in the order given by
         * {@code getSortedResponseIndices} (presumably ascending response order;
         * confirm against the provider). A candidate threshold is placed midway
         * between each pair of consecutive, non-equal responses, and both
         * polarities are evaluated there. All comparisons are strict, so within one
         * range the first candidate reaching a given error is kept.
         *
         * <p>Dimensions are handed out through {@code Parallel.forRange}; the best
         * result of each range is merged under a lock on the shared result. When
         * two ranges reach exactly the same error, the range merged first is kept.
         *
         * <p>If no candidate exists (for example fewer than two instances, no
         * dimensions, or no two consecutive responses differ), the returned stump
         * keeps its default field values ({@code 0}, {@code 0.0f}, {@code 0}) and
         * the paired error is {@link Float#POSITIVE_INFINITY}.
         *
         * @param labelledDataProvider supplies the class labels, the instance and
         *        dimension counts, and the per-dimension responses and orderings
         * @param fArr per-instance weights, indexed like the class labels; read for
         *        indices {@code 0 .. numInstances() - 1}
         * @return the selected stump paired with its error
         */
        public ObjectFloatPair<StumpClassifier> learn(final LabelledDataProvider labelledDataProvider, final float[] fArr) {
            final StumpClassifier best = new StumpClassifier();
            // One-element array so the anonymous operation can update the running minimum.
            final float[] lowestError = {Float.POSITIVE_INFINITY};
            final boolean[] classes = labelledDataProvider.getClasses();
            final int instanceCount = labelledDataProvider.numInstances();

            final float weightSum = totalWeight(fArr, instanceCount);
            final float negativeWeight = weightOfNegatives(classes, fArr, instanceCount);

            Parallel.forRange(0, labelledDataProvider.numDimensions(), 1, new Operation<Parallel.IntRange>() {
                public void perform(Parallel.IntRange intRange) {
                    // Best candidate seen in this range; sentinels mean "none yet".
                    int rangeDimension = -1;
                    float rangeThreshold = Float.NaN;
                    int rangeSign = 0;
                    float rangeError = Float.POSITIVE_INFINITY;

                    for (int dim = intRange.start; dim < intRange.stop; dim += intRange.incr) {
                        final float[] responses = labelledDataProvider.getFeatureResponse(dim);
                        final int[] order = labelledDataProvider.getSortedResponseIndices(dim);

                        // Error of a sign=+1 stump whose threshold lies before every
                        // instance: everything is labelled true, so only the
                        // false-class weight is wrong.
                        float errorPositive = negativeWeight;

                        for (int rank = 0; rank < instanceCount - 1; rank++) {
                            final int instance = order[rank];
                            // Moving the threshold past this instance flips its
                            // prediction to false.
                            if (classes[instance]) {
                                errorPositive += fArr[instance];
                            } else {
                                errorPositive -= fArr[instance];
                            }

                            final float here = responses[instance];
                            final float next = responses[order[rank + 1]];
                            if (here == next) {
                                // No threshold can separate identical responses.
                                continue;
                            }

                            final float midpoint = (here + next) / 2.0f;
                            if (errorPositive < rangeError) {
                                rangeError = errorPositive;
                                rangeDimension = dim;
                                rangeThreshold = midpoint;
                                rangeSign = 1;
                            }

                            // The opposite polarity gets wrong exactly what this one
                            // gets right.
                            final float errorNegative = weightSum - errorPositive;
                            if (errorNegative < rangeError) {
                                rangeError = errorNegative;
                                rangeDimension = dim;
                                rangeThreshold = midpoint;
                                rangeSign = -1;
                            }
                        }
                    }

                    // Merge this range's winner into the shared result.
                    synchronized (best) {
                        if (rangeError < lowestError[0]) {
                            lowestError[0] = rangeError;
                            best.dimension = rangeDimension;
                            best.sign = rangeSign;
                            best.threshold = rangeThreshold;
                        }
                    }
                }
            });

            return new ObjectFloatPair<>(best, lowestError[0]);
        }

        /** Adds up the first {@code count} entries of {@code weights}. */
        private static float totalWeight(final float[] weights, final int count) {
            float sum = 0.0f;
            for (int i = 0; i < count; i++) {
                sum += weights[i];
            }
            return sum;
        }

        /** Adds up the weights of the first {@code count} instances whose class flag is {@code false}. */
        private static float weightOfNegatives(final boolean[] classes, final float[] weights, final int count) {
            float sum = 0.0f;
            for (int i = 0; i < count; i++) {
                if (!classes[i]) {
                    sum += weights[i];
                }
            }
            return sum;
        }
    }

    /**
     * Classifies a feature vector using the response at {@link #dimension}.
     *
     * @param fArr feature responses; must be indexable at {@link #dimension}
     * @return {@code true} when the response is strictly above the threshold and
     *         {@link #sign} is {@code 1}, or when it is not and {@link #sign} is
     *         {@code -1}; {@code false} otherwise
     */
    public boolean classify(float[] fArr) {
        if (fArr[this.dimension] > this.threshold) {
            return this.sign == 1;
        }
        // -sign == 1 holds exactly when sign == -1.
        return this.sign == -1;
    }

    /**
     * Classifies a single, already selected feature response.
     *
     * @param f the response to compare with {@link #threshold}
     * @return {@code true} when {@code f} is strictly above the threshold and
     *         {@link #sign} is {@code 1}, or when it is not and {@link #sign} is
     *         {@code -1}; {@code false} otherwise
     */
    public boolean classify(float f) {
        if (f > this.threshold) {
            return this.sign == 1;
        }
        // -sign == 1 holds exactly when sign == -1.
        return this.sign == -1;
    }
}
