package org.openimaj.ml.annotation.linear;

import gov.sandia.cognition.learning.algorithm.svm.PrimalEstimatedSubGradient;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.openimaj.feature.FeatureExtractor;
import org.openimaj.feature.FeatureVector;
import org.openimaj.ml.annotation.Annotated;
import org.openimaj.ml.annotation.BatchAnnotator;
import org.openimaj.ml.annotation.ScoredAnnotation;
import org.openimaj.ml.annotation.utils.AnnotatedListHelper;

/* loaded from: input_file:org/openimaj/ml/annotation/linear/LinearSVMAnnotator.class */
/**
 * Annotator that trains one linear binary classifier per annotation
 * (one-versus-rest) using {@link PrimalEstimatedSubGradient}, and annotates an
 * object with every annotation whose classifier produces a positive score.
 * <p>
 * An optional "negative class" annotation can be supplied; it is used as
 * training data for the other classifiers (as part of their negative
 * examples) but is never itself produced by {@link #annotate(Object)}.
 * <p>
 * Until {@link #train(List)} has been called the annotator knows no
 * annotations: {@link #getAnnotations()} returns an empty set and
 * {@link #annotate(Object)} returns an empty list.
 *
 * @param <OBJECT>
 *            type of object being annotated
 * @param <ANNOTATION>
 *            type of annotation
 * @param <EXTRACTOR>
 *            type of the feature extractor applied to each object
 */
public class LinearSVMAnnotator<OBJECT, ANNOTATION, EXTRACTOR extends FeatureExtractor<? extends FeatureVector, OBJECT>> extends BatchAnnotator<OBJECT, ANNOTATION, EXTRACTOR> {
    /** Trained classifier for each annotation other than the negative class. */
    private final Map<ANNOTATION, LinearBinaryCategorizer> classifiers = new HashMap<ANNOTATION, LinearBinaryCategorizer>();

    /** Annotations seen in the most recent training set; empty before training. */
    private Set<ANNOTATION> annotations = Collections.<ANNOTATION>emptySet();

    /** Annotation that is never emitted by {@link #annotate(Object)}; may be null. */
    private final ANNOTATION negativeClass;

    /**
     * Construct with the given extractor and negative class.
     *
     * @param extractor
     *            the feature extractor
     * @param annotation
     *            the annotation to treat as the negative class (never
     *            returned from {@link #annotate(Object)}); may be null
     */
    public LinearSVMAnnotator(EXTRACTOR extractor, ANNOTATION annotation) {
        super(extractor);
        this.negativeClass = annotation;
    }

    /**
     * Construct with the given extractor and no negative class.
     *
     * @param extractor
     *            the feature extractor
     */
    public LinearSVMAnnotator(EXTRACTOR extractor) {
        this(extractor, null);
    }

    /**
     * Train one classifier per annotation found in the data. For each
     * annotation the positive examples are the features of items carrying
     * that annotation and the negative examples are the features of all
     * other items.
     * <p>
     * Any previously trained state is replaced, and only once every
     * classifier has been learned, so a failure part-way through leaves the
     * earlier state intact.
     *
     * @param list
     *            the annotated training data
     */
    @Override // org.openimaj.ml.training.BatchTrainer
    public void train(List<? extends Annotated<OBJECT, ANNOTATION>> list) {
        final AnnotatedListHelper<OBJECT, ANNOTATION> helper = new AnnotatedListHelper<OBJECT, ANNOTATION>(list);
        final Set<ANNOTATION> labels = helper.getAnnotations();
        final FeatureExtractor<? extends FeatureVector, OBJECT> featureExtractor = this.extractor;

        final Map<ANNOTATION, LinearBinaryCategorizer> trained = new HashMap<ANNOTATION, LinearBinaryCategorizer>();
        for (final ANNOTATION annotation : labels) {
            // The negative class is never scored by annotate(), so a
            // classifier for it would never be consulted.
            if (isNegativeClass(annotation)) {
                continue;
            }

            final List<? extends FeatureVector> positives = helper.extractFeatures(annotation, featureExtractor);
            final List<? extends FeatureVector> negatives = helper.extractFeaturesExclude(annotation, featureExtractor);

            final PrimalEstimatedSubGradient learner = new PrimalEstimatedSubGradient();
            learner.learn(convert(positives, negatives));
            trained.put(annotation, learner.getResult());
        }

        // Swap in the new state in one go, dropping classifiers left over
        // from any earlier call to train().
        this.annotations = labels;
        this.classifiers.clear();
        this.classifiers.putAll(trained);
    }

    /**
     * Build the labelled training collection expected by the learner.
     *
     * @param list
     *            features of the positive examples (labelled {@code true})
     * @param list2
     *            features of the negative examples (labelled {@code false})
     * @return the positive pairs followed by the negative pairs
     */
    private Collection<? extends InputOutputPair<? extends Vectorizable, Boolean>> convert(List<? extends FeatureVector> list, List<? extends FeatureVector> list2) {
        final List<InputOutputPair<Vector, Boolean>> pairs = new ArrayList<InputOutputPair<Vector, Boolean>>(list.size() + list2.size());
        for (final FeatureVector positive : list) {
            pairs.add(new DefaultInputOutputPair<Vector, Boolean>(convert(positive), Boolean.TRUE));
        }
        for (final FeatureVector negative : list2) {
            pairs.add(new DefaultInputOutputPair<Vector, Boolean>(convert(negative), Boolean.FALSE));
        }
        return pairs;
    }

    /**
     * @return the annotations found in the most recent training set
     *         (including the negative class if it was present); an empty set
     *         if {@link #train(List)} has not been called
     */
    @Override // org.openimaj.ml.annotation.Annotator
    public Set<ANNOTATION> getAnnotations() {
        return this.annotations;
    }

    /**
     * Annotate an object with every non-negative-class annotation whose
     * classifier gives a strictly positive score; the score is used as the
     * confidence of the annotation.
     *
     * @param object
     *            the object to annotate
     * @return the predicted annotations; empty if none scored positively or
     *         if the annotator has not been trained
     */
    @Override // org.openimaj.ml.annotation.Annotator
    public List<ScoredAnnotation<ANNOTATION>> annotate(OBJECT object) {
        final List<ScoredAnnotation<ANNOTATION>> results = new ArrayList<ScoredAnnotation<ANNOTATION>>();

        // The feature vector does not depend on the annotation being tested,
        // so it is extracted at most once, and only if a classifier actually
        // needs it. NOTE(review): assumes the extractor returns the same
        // feature for repeated calls on the same object.
        Vector vector = null;

        for (final ANNOTATION annotation : this.annotations) {
            if (isNegativeClass(annotation)) {
                continue;
            }

            final LinearBinaryCategorizer classifier = this.classifiers.get(annotation);
            if (classifier == null) {
                // No usable model for this annotation; nothing to score.
                continue;
            }

            if (vector == null) {
                final FeatureVector feature = this.extractor.extractFeature(object);
                vector = convert(feature);
            }

            final double score = classifier.evaluateAsDouble(vector);
            if (score > 0.0d) {
                results.add(new ScoredAnnotation<ANNOTATION>(annotation, (float) score));
            }
        }
        return results;
    }

    /**
     * Test whether an annotation is the configured negative class.
     *
     * @param annotation
     *            the annotation to test
     * @return true if the annotation equals the negative class
     */
    private boolean isNegativeClass(ANNOTATION annotation) {
        return annotation != null && annotation.equals(this.negativeClass);
    }

    /**
     * Copy a feature vector into a dense vector usable by the classifiers.
     *
     * @param featureVector
     *            the feature to convert
     * @return a dense vector holding the feature's values as doubles
     */
    private Vector convert(FeatureVector featureVector) {
        return VectorFactory.getDenseDefault().copyArray(featureVector.asDoubleVector());
    }
}
