package org.openimaj.image.objectdetection.haar.training;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.openimaj.image.ImageUtilities;
import org.openimaj.image.analysis.algorithm.SummedSqTiltAreaTable;
import org.openimaj.image.objectdetection.haar.HaarFeature;
import org.openimaj.image.objectdetection.haar.HaarFeatureClassifier;
import org.openimaj.image.objectdetection.haar.Stage;
import org.openimaj.image.objectdetection.haar.StageTreeClassifier;
import org.openimaj.image.objectdetection.haar.ValueClassifier;
import org.openimaj.io.IOUtils;
import org.openimaj.ml.classification.StumpClassifier;
import org.openimaj.ml.classification.boosting.AdaBoost;
import org.openimaj.util.pair.ObjectFloatPair;

/* loaded from: input_file:org/openimaj/image/objectdetection/haar/training/Testing.class */
/**
 * Experimental driver that trains a single-stage Haar classifier with {@link AdaBoost} on a
 * directory of positive and negative {@code .pgm} training images, prints the ensemble's
 * classification quality over a sweep of thresholds, checks that the resulting
 * {@link StageTreeClassifier} agrees with the raw AdaBoost ensemble on every training
 * instance, and writes the cascade to {@code test-classifier.bin}.
 */
public class Testing {
    /** Width of the detection window used for feature generation and for the cascade. */
    private static final int WINDOW_WIDTH = 19;
    /** Height of the detection window used for feature generation and for the cascade. */
    private static final int WINDOW_HEIGHT = 19;
    /** Number of boosting rounds requested from {@link AdaBoost#learn}. */
    private static final int NUM_ROUNDS = 500;
    /**
     * Root of the training data; contains {@code face} and {@code non-face} sub-directories.
     * Can be overridden with {@code -Dcbcl.faces.root=...}; defaults to the original path.
     */
    private static final String DATA_ROOT =
            System.getProperty("cbcl.faces.root", "/Users/jsh2/Data/cbcl-faces/train");
    /** Only files whose names end with this suffix are loaded. */
    private static final String IMAGE_SUFFIX = ".pgm";
    /** File the trained cascade is written to (relative to the working directory). */
    private static final String OUTPUT_FILE = "test-classifier.bin";

    /** Haar features produced by {@link #createFeatures}; indexed by {@code StumpClassifier.dimension}. */
    List<HaarFeature> features;
    /** Summed-area tables of the positive training images, in load order. */
    List<SummedSqTiltAreaTable> positive = new ArrayList<>();
    /** Summed-area tables of the negative training images, in load order. */
    List<SummedSqTiltAreaTable> negative = new ArrayList<>();

    /**
     * Generates the CORE Haar feature set for a window and sets every feature's scale.
     *
     * @param width window width in pixels; must be greater than 2
     * @param height window height in pixels; must be greater than 2
     * @throws IllegalArgumentException if either dimension is 2 or less, which would make the
     *     weight scale infinite or negative
     */
    void createFeatures(int width, int height) {
        if (width <= 2 || height <= 2) {
            throw new IllegalArgumentException(
                    "Window must be larger than 2x2, got " + width + "x" + height);
        }
        this.features = HaarFeatureType.generateFeatures(width, height, HaarFeatureType.CORE);

        // Reciprocal of the window area with one pixel trimmed from each edge.
        // NOTE(review): presumably mirrors the normalisation used by the cascade at detection
        // time — confirm against HaarFeature.setScale.
        final float weightScale = 1.0f / ((width - 2) * (height - 2));
        for (HaarFeature feature : this.features) {
            feature.setScale(1.0f, weightScale);
        }
    }

    /**
     * Reads {@code file} as a float image and appends its {@link SummedSqTiltAreaTable} to
     * {@code list}.
     *
     * @param file image file to read
     * @param list destination list
     * @param z currently unused. NOTE(review): possibly intended for the table constructor's
     *     boolean argument, which is hard-coded to {@code false} — confirm.
     * @throws IOException if the image cannot be read
     */
    void loadImage(File file, List<SummedSqTiltAreaTable> list, boolean z) throws IOException {
        list.add(new SummedSqTiltAreaTable(ImageUtilities.readF(file), false));
    }

    /**
     * Loads every {@code .pgm} image from the {@code face} directory into {@link #positive}.
     *
     * @param z forwarded to {@link #loadImage} (currently unused there)
     * @throws IOException if the directory cannot be listed or an image cannot be read
     */
    void loadPositive(boolean z) throws IOException {
        loadDirectory(new File(DATA_ROOT, "face"), this.positive, z);
    }

    /**
     * Loads every {@code .pgm} image from the {@code non-face} directory into {@link #negative}.
     *
     * @param z forwarded to {@link #loadImage} (currently unused there)
     * @throws IOException if the directory cannot be listed or an image cannot be read
     */
    void loadNegative(boolean z) throws IOException {
        loadDirectory(new File(DATA_ROOT, "non-face"), this.negative, z);
    }

    /** Loads every file in {@code dir} ending in {@link #IMAGE_SUFFIX} via {@link #loadImage}. */
    private void loadDirectory(File dir, List<SummedSqTiltAreaTable> list, boolean z)
            throws IOException {
        final File[] files = dir.listFiles();
        // listFiles() returns null (rather than throwing) when dir is missing or unreadable.
        if (files == null) {
            throw new IOException("Unable to list image directory: " + dir);
        }
        for (File file : files) {
            if (file.getName().endsWith(IMAGE_SUFFIX)) {
                loadImage(file, list, z);
            }
        }
    }

    /**
     * Runs the full experiment: feature generation, image loading, AdaBoost training, a
     * threshold sweep, cascade/ensemble agreement checking, and saving the cascade.
     *
     * @throws IOException if images cannot be loaded or the classifier cannot be written
     */
    void perform() throws IOException {
        System.out.println("Creating feature set");
        createFeatures(WINDOW_WIDTH, WINDOW_HEIGHT);
        System.out.println("Loading positive images and computing SATs");
        loadPositive(false);
        System.out.println("Loading negative images and computing SATs");
        loadNegative(false);
        System.out.println("+ve: " + this.positive.size());
        System.out.println("-ve: " + this.negative.size());
        System.out.println("features: " + this.features.size());

        System.out.println("Computing cached feature sets");
        final CachedTrainingData data =
                new CachedTrainingData(this.positive, this.negative, this.features);

        System.out.println("Starting Training");
        final AdaBoost adaBoost = new AdaBoost();
        final List<ObjectFloatPair<StumpClassifier>> ensemble = adaBoost.learn(data, NUM_ROUNDS);
        System.out.println("Training complete. Ensemble has " + ensemble.size() + " classifiers.");

        printThresholdSweep(adaBoost, data, ensemble);

        final StageTreeClassifier cascade = new StageTreeClassifier(
                WINDOW_WIDTH, WINDOW_HEIGHT, "test cascade", false, createStage(ensemble));
        cascade.setScale(1.0f);

        verifyCascade(cascade, data, ensemble);
        saveClassifier(cascade, new File(OUTPUT_FILE));
    }

    /** Prints the ensemble's classification quality for thresholds 3.0, 2.75, ..., -3.0. */
    private static void printThresholdSweep(AdaBoost adaBoost, CachedTrainingData data,
            List<ObjectFloatPair<StumpClassifier>> ensemble) {
        // 0.25f is exactly representable, so the float counter hits -3.0f exactly.
        for (float threshold = 3.0f; threshold >= -3.0f; threshold -= 0.25f) {
            System.out.println("Threshold = " + threshold);
            adaBoost.printClassificationQuality(data, ensemble, threshold);
        }
    }

    /**
     * Prints "ERROR" (positives) or "ERROR2" (negatives) for every training instance on which
     * the cascade's decision differs from the raw AdaBoost ensemble's decision.
     */
    private void verifyCascade(StageTreeClassifier cascade, CachedTrainingData data,
            List<ObjectFloatPair<StumpClassifier>> ensemble) {
        final int numPositive = this.positive.size();
        for (int i = 0; i < numPositive; i++) {
            final boolean cascadeResult = cascade.classify(this.positive.get(i), 0, 0) == 1;
            if (cascadeResult != AdaBoost.classify(data.getInstanceFeature(i), ensemble)) {
                System.out.println("ERROR");
            }
        }
        // In the cached data the negative instances are indexed after all the positives.
        for (int i = 0; i < this.negative.size(); i++) {
            final int stageResult = cascade.classify(this.negative.get(i), 0, 0);
            final boolean boostResult =
                    AdaBoost.classify(data.getInstanceFeature(numPositive + i), ensemble);
            if ((stageResult == 1) != boostResult) {
                System.out.println(stageResult + " " + boostResult);
                System.out.println("ERROR2");
            }
        }
    }

    /** Serialises {@code cascade} to {@code file}, closing the streams even on failure. */
    private static void saveClassifier(StageTreeClassifier cascade, File file) throws IOException {
        try (FileOutputStream fileOut = new FileOutputStream(file);
                ObjectOutputStream out = new ObjectOutputStream(fileOut)) {
            IOUtils.write(cascade, out);
        }
    }

    /**
     * Converts a boosted ensemble of decision stumps into a single cascade {@link Stage}.
     * Each stump becomes a {@link HaarFeatureClassifier} that uses the stump's feature and
     * threshold. Its two leaf values are {@code ±alpha}, with the sign chosen from the
     * stump's {@code sign}.
     *
     * @param list the weighted stumps returned by {@link AdaBoost#learn}
     * @return a stage with threshold {@code 0.0f} containing one classifier per stump
     */
    private Stage createStage(List<ObjectFloatPair<StumpClassifier>> list) {
        final HaarFeatureClassifier[] classifiers = new HaarFeatureClassifier[list.size()];
        for (int i = 0; i < classifiers.length; i++) {
            final ObjectFloatPair<StumpClassifier> weighted = list.get(i);
            final StumpClassifier stump = weighted.first;
            final float alpha = weighted.second;
            // NOTE(review): assumes the first ValueClassifier is the branch taken below the
            // threshold, so a positive-sign stump votes -alpha there — confirm.
            final float leftValue = stump.sign > 0 ? -alpha : alpha;
            classifiers[i] = new HaarFeatureClassifier(this.features.get(stump.dimension),
                    stump.threshold, new ValueClassifier(leftValue), new ValueClassifier(-leftValue));
        }
        return new Stage(0.0f, classifiers, (Stage) null, (Stage) null);
    }

    /**
     * Entry point; runs {@link #perform()}.
     *
     * @param strArr ignored
     * @throws IOException if loading or saving fails
     */
    public static void main(String[] strArr) throws IOException {
        new Testing().perform();
    }
}
