package org.openimaj.experiment.evaluation.classification;

import gnu.trove.map.hash.TObjectDoubleHashMap;
import gnu.trove.procedure.TObjectDoubleProcedure;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.openimaj.util.pair.ObjectDoublePair;

/* loaded from: input_file:org/openimaj/experiment/evaluation/classification/BasicClassificationResult.class */
/**
 * A simple {@link ClassificationResult} backed by a map from class to confidence.
 * A class is reported as predicted only when its confidence is strictly greater
 * than a fixed threshold.
 *
 * @param <CLASS> the type of the class labels
 */
public class BasicClassificationResult<CLASS> implements ClassificationResult<CLASS> {
    /** Confidence recorded for each class via {@link #put(Object, double)}. */
    private final TObjectDoubleHashMap<CLASS> data = new TObjectDoubleHashMap<>();

    /** A class is predicted only if its confidence is strictly greater than this value. */
    private final double threshold;

    /**
     * Construct an empty result with a threshold of {@code 0}, so that only
     * classes with a strictly positive confidence are predicted.
     */
    public BasicClassificationResult() {
        this(0.0d);
    }

    /**
     * Construct an empty result with the given threshold.
     *
     * @param threshold classes are predicted only if their confidence is strictly
     *                  greater than this value
     */
    public BasicClassificationResult(double threshold) {
        this.threshold = threshold;
    }

    /**
     * Record the confidence for a class, replacing any value previously recorded
     * for the same class.
     *
     * @param clazz      the class
     * @param confidence the confidence associated with {@code clazz}
     */
    public void put(CLASS clazz, double confidence) {
        data.put(clazz, confidence);
    }

    /**
     * Get the confidence recorded for a class. For a class that was never
     * {@link #put(Object, double) put}, the backing Trove map's "no entry" value
     * is returned (presumably 0.0 for a default-constructed map — confirm
     * against the Trove version in use).
     *
     * @param clazz the class
     * @return the recorded confidence
     */
    @Override
    public double getConfidence(CLASS clazz) {
        return data.get(clazz);
    }

    /**
     * Get the classes whose confidence is strictly greater than the threshold.
     * The returned set iterates in order of descending confidence. Classes with
     * a NaN confidence are never included.
     *
     * @return a new, mutable set of the predicted classes
     */
    @Override
    public Set<CLASS> getPredictedClasses() {
        final List<ObjectDoublePair<CLASS>> accepted = new ArrayList<>();
        final double cutoff = threshold;

        data.forEachEntry(new TObjectDoubleProcedure<CLASS>() {
            @Override
            public boolean execute(CLASS clazz, double confidence) {
                // Positive comparison on purpose: "confidence > cutoff" is false for
                // NaN, so NaN confidences are rejected. The negated form
                // "confidence <= cutoff -> skip" would let NaN through.
                if (confidence > cutoff) {
                    accepted.add(new ObjectDoublePair<CLASS>(clazz, confidence));
                }
                return true; // always continue iterating over the map
            }
        });

        Collections.sort(accepted, new Comparator<ObjectDoublePair<CLASS>>() {
            @Override
            public int compare(ObjectDoublePair<CLASS> a, ObjectDoublePair<CLASS> b) {
                // Arguments swapped to sort by descending confidence.
                return Double.compare(b.second, a.second);
            }
        });

        // LinkedHashSet preserves the sorted (descending-confidence) order.
        final Set<CLASS> predicted = new LinkedHashSet<>(accepted.size());
        for (final ObjectDoublePair<CLASS> pair : accepted) {
            predicted.add(pair.first);
        }
        return predicted;
    }
}
