package org.openimaj.ml.classification.boosting;

import java.util.ArrayList;
import java.util.List;
import org.openimaj.ml.classification.LabelledDataProvider;
import org.openimaj.ml.classification.StumpClassifier;
import org.openimaj.util.pair.ObjectFloatPair;

/* loaded from: input_file:org/openimaj/ml/classification/boosting/AdaBoost.class */
/**
 * Discrete AdaBoost over {@link StumpClassifier} weak learners.
 *
 * <p>{@link #learn} builds an ensemble of (stump, vote weight) pairs, and the
 * static {@code classify} methods evaluate such an ensemble as a weighted vote
 * in which each stump contributes {@code +weight} for a positive prediction and
 * {@code -weight} for a negative one.
 */
public class AdaBoost {
    /** Weak learner used to fit one stump per boosting round against the current instance weights. */
    StumpClassifier.WeightedLearner factory = new StumpClassifier.WeightedLearner();

    /**
     * Learns a boosted ensemble of stumps.
     *
     * <p>Instance weights start uniform. Each round fits a stump with the
     * current weights, measures its weighted training error, and derives its
     * vote weight as {@code 0.5 * ln((1 - error) / error)}. Training stops
     * early, without keeping the stump, when the weighted error reaches 0.5;
     * it also stops, after keeping the stump, when the weighted error is
     * exactly 0 (that stump's vote weight is then infinite).
     *
     * <p>A progress line is written to {@code System.out} on every round.
     *
     * @param trainingSet    the labelled training data
     * @param numberOfRounds the maximum number of boosting rounds
     * @return the learned stumps paired with their vote weights, in the order
     *         they were learned; possibly empty
     */
    public List<ObjectFloatPair<StumpClassifier>> learn(LabelledDataProvider trainingSet, int numberOfRounds) {
        // The instance count is queried once and assumed stable during training.
        final int numInstances = trainingSet.numInstances();

        final float[] weights = new float[numInstances];
        final float uniformWeight = 1.0f / numInstances;
        for (int j = 0; j < numInstances; j++) {
            weights[j] = uniformWeight;
        }

        final boolean[] classes = trainingSet.getClasses();
        final List<ObjectFloatPair<StumpClassifier>> ensemble = new ArrayList<ObjectFloatPair<StumpClassifier>>();

        for (int round = 0; round < numberOfRounds; round++) {
            System.out.println("Iteration: " + round);

            final ObjectFloatPair<StumpClassifier> learned = this.factory.learn(trainingSet, weights);
            final StumpClassifier stump = learned.first;

            // Only the single feature the stump splits on is needed to classify every instance.
            final float[] responses = trainingSet.getFeatureResponse(stump.dimension);
            final boolean[] predictions = new boolean[numInstances];
            double weightedError = 0.0d;
            for (int j = 0; j < numInstances; j++) {
                predictions[j] = stump.classify(responses[j]);
                if (predictions[j] != classes[j]) {
                    weightedError += weights[j];
                }
            }

            // A stump no better than chance would get a non-positive vote weight: stop boosting.
            if (weightedError >= 0.5d) {
                break;
            }

            final float alpha = (float) (0.5d * Math.log((1.0d - weightedError) / weightedError));
            ensemble.add(new ObjectFloatPair<StumpClassifier>(stump, alpha));

            // A perfect stump ends training; re-weighting with an infinite alpha would
            // only fill the (no longer used) weight array with NaN, so it is skipped.
            if (weightedError == 0.0d) {
                break;
            }

            reweight(weights, classes, predictions, alpha);
        }
        return ensemble;
    }

    /**
     * Applies the AdaBoost weight update in place: each weight is scaled by
     * {@code exp(-alpha)} when the prediction matched the label and by
     * {@code exp(+alpha)} when it did not, then all weights are divided by
     * their sum.
     */
    private static void reweight(float[] weights, boolean[] classes, boolean[] predictions, float alpha) {
        float sum = 0.0f;
        for (int j = 0; j < weights.length; j++) {
            final int label = classes[j] ? 1 : -1;
            final int predicted = predictions[j] ? 1 : -1;
            weights[j] = (float) (weights[j] * Math.exp((-alpha) * label * predicted));
            sum += weights[j];
        }
        for (int j = 0; j < weights.length; j++) {
            weights[j] /= sum;
        }
    }

    /**
     * Classifies every instance of {@code data} with the ensemble and prints
     * the confusion-matrix counts followed by the false-positive and
     * true-positive rates to {@code System.out}.
     *
     * <p>A rate whose denominator is zero (no negative, respectively no
     * positive, instances) is printed as {@code NaN}.
     *
     * @param data      the labelled data to evaluate on
     * @param ensemble  the stumps and their vote weights
     * @param threshold the decision threshold passed to
     *                  {@link #classify(float[], List, float)}
     */
    public void printClassificationQuality(LabelledDataProvider data, List<ObjectFloatPair<StumpClassifier>> ensemble, float threshold) {
        int truePositives = 0;
        int falseNegatives = 0;
        int trueNegatives = 0;
        int falsePositives = 0;

        final int numInstances = data.numInstances();
        final boolean[] classes = data.getClasses();
        for (int j = 0; j < numInstances; j++) {
            final boolean predicted = classify(data.getInstanceFeature(j), ensemble, threshold);
            if (classes[j]) {
                if (predicted) {
                    truePositives++;
                } else {
                    falseNegatives++;
                }
            } else if (predicted) {
                falsePositives++;
            } else {
                trueNegatives++;
            }
        }

        System.out.format("TP: %d\tFN: %d\tFP: %d\tTN: %d\n",
                Integer.valueOf(truePositives), Integer.valueOf(falseNegatives),
                Integer.valueOf(falsePositives), Integer.valueOf(trueNegatives));

        // Divide as floats: integer division would truncate every rate to 0 or 1
        // and throw ArithmeticException when a class has no instances.
        final float falsePositiveRate = (float) falsePositives / (float) (falsePositives + trueNegatives);
        final float truePositiveRate = (float) truePositives / (float) (truePositives + falseNegatives);
        System.out.format("FPR: %2.2f\tTPR: %2.2f\n",
                Float.valueOf(falsePositiveRate), Float.valueOf(truePositiveRate));
    }

    /**
     * Classifies a feature vector using a decision threshold of zero.
     *
     * @param instanceFeature the feature vector to classify
     * @param ensemble        the stumps and their vote weights
     * @return {@code true} if the weighted vote is strictly greater than zero
     */
    public static boolean classify(float[] instanceFeature, List<ObjectFloatPair<StumpClassifier>> ensemble) {
        return classify(instanceFeature, ensemble, 0.0f);
    }

    /**
     * Classifies a feature vector by weighted vote of the ensemble.
     *
     * @param instanceFeature the feature vector to classify
     * @param ensemble        the stumps and their vote weights
     * @param threshold       the value the weighted vote must strictly exceed
     * @return {@code true} if the sum of {@code +weight} (stump says positive)
     *         and {@code -weight} (stump says negative) over all stumps is
     *         strictly greater than {@code threshold}
     */
    public static boolean classify(float[] instanceFeature, List<ObjectFloatPair<StumpClassifier>> ensemble, float threshold) {
        double vote = 0.0d;
        for (final ObjectFloatPair<StumpClassifier> member : ensemble) {
            vote += member.second * (member.first.classify(instanceFeature) ? 1 : -1);
        }
        return vote > (double) threshold;
    }
}
