package org.openimaj.workinprogress.featlearn.cifarexps;

import de.bwaldvogel.liblinear.SolverType;
import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicInteger;
import org.openimaj.data.dataset.GroupedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.feature.DoubleFV;
import org.openimaj.feature.FeatureExtractor;
import org.openimaj.image.MBFImage;
import org.openimaj.image.annotation.evaluation.datasets.CIFAR10Dataset;
import org.openimaj.math.statistics.normalisation.ZScore;
import org.openimaj.ml.annotation.AnnotatedObject;
import org.openimaj.ml.annotation.linear.LiblinearAnnotator;
import org.openimaj.util.function.Operation;
import org.openimaj.util.parallel.Parallel;
import org.openimaj.workinprogress.featlearn.RandomPatchSampler;

/* loaded from: input_file:org/openimaj/workinprogress/featlearn/cifarexps/CIFARExperimentFramework.class */
public abstract class CIFARExperimentFramework {
    protected final int patchSize = 6;
    protected final int numPatches = 400000;
    protected final int C = 1;

    protected abstract void learnFeatures(double[][] dArr);

    protected abstract double[] extractFeatures(MBFImage mBFImage);

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v19, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v4, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v40, types: [double[], double[][]] */
    public double run() throws IOException {
        final GroupedDataset trainingImages = CIFAR10Dataset.getTrainingImages(CIFAR10Dataset.MBFIMAGE_READER);
        RandomPatchSampler randomPatchSampler = new RandomPatchSampler(trainingImages, 6, 6, 400000);
        ?? r0 = new double[400000];
        int i = 0;
        Iterator it = randomPatchSampler.iterator();
        while (it.hasNext()) {
            int i2 = i;
            i++;
            r0[i2] = ((MBFImage) it.next()).getDoublePixelVector();
            if (i % 10000 == 0) {
                System.out.format("Extracting patch %d / %d\n", Integer.valueOf(i), 400000);
            }
        }
        learnFeatures(r0);
        final MBFImage[] mBFImageArr = new MBFImage[trainingImages.numInstances()];
        String[] strArr = new String[trainingImages.numInstances()];
        final ?? r02 = new double[trainingImages.numInstances()];
        int i3 = 0;
        for (String str : trainingImages.getGroups()) {
            Iterator it2 = ((ListDataset) trainingImages.get(str)).iterator();
            while (it2.hasNext()) {
                mBFImageArr[i3] = (MBFImage) it2.next();
                strArr[i3] = str;
                i3++;
            }
        }
        Parallel.forRange(0, mBFImageArr.length, 1, new Operation<Parallel.IntRange>() { // from class: org.openimaj.workinprogress.featlearn.cifarexps.CIFARExperimentFramework.1
            volatile int count = 0;

            public void perform(Parallel.IntRange intRange) {
                for (int i4 = intRange.start; i4 < intRange.stop; i4++) {
                    if (this.count % 100 == 0) {
                        System.out.format("Extracting features %d / %d\n", Integer.valueOf(this.count), Integer.valueOf(trainingImages.numInstances()));
                    }
                    r02[i4] = CIFARExperimentFramework.this.extractFeatures(mBFImageArr[i4]);
                    this.count++;
                }
            }
        });
        ZScore zScore = new ZScore(0.01d);
        zScore.train((double[][]) r02);
        double[][] normalise = zScore.normalise((double[][]) r02);
        LiblinearAnnotator liblinearAnnotator = new LiblinearAnnotator(new FeatureExtractor<DoubleFV, double[]>() { // from class: org.openimaj.workinprogress.featlearn.cifarexps.CIFARExperimentFramework.2
            public DoubleFV extractFeature(double[] dArr) {
                return new DoubleFV(dArr);
            }
        }, LiblinearAnnotator.Mode.MULTICLASS, SolverType.L2R_L2LOSS_SVC_DUAL, 1.0d, 0.1d, 1.0d, true);
        liblinearAnnotator.train(AnnotatedObject.createList(normalise, strArr));
        GroupedDataset testImages = CIFAR10Dataset.getTestImages(CIFAR10Dataset.MBFIMAGE_READER);
        String[] strArr2 = new String[testImages.numInstances()];
        ?? r03 = new double[testImages.numInstances()];
        int i4 = 0;
        for (String str2 : testImages.getGroups()) {
            Iterator it3 = ((ListDataset) testImages.get(str2)).iterator();
            while (it3.hasNext()) {
                r03[i4] = extractFeatures((MBFImage) it3.next());
                strArr2[i4] = str2;
                i4++;
            }
        }
        double[][] normalise2 = zScore.normalise((double[][]) r03);
        double d = 0.0d;
        double d2 = 0.0d;
        for (int i5 = 0; i5 < normalise2.length; i5++) {
            if (((String) liblinearAnnotator.classify(normalise2[i5]).getPredictedClasses().iterator().next()).equals(strArr2[i5])) {
                d += 1.0d;
            } else {
                d2 += 1.0d;
            }
        }
        double d3 = d / (d + d2);
        System.out.println("Test accuracy " + d3);
        return d3;
    }
}
