package org.openimaj.ml.annotation.svm;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
import org.openimaj.feature.FeatureExtractor;
import org.openimaj.feature.FeatureVector;
import org.openimaj.ml.annotation.Annotated;
import org.openimaj.ml.annotation.BatchAnnotator;
import org.openimaj.ml.annotation.ScoredAnnotation;
import org.openimaj.ml.annotation.utils.AnnotatedListHelper;
import org.openimaj.util.array.ArrayUtils;

/* loaded from: input_file:org/openimaj/ml/annotation/svm/SVMAnnotator.class */
public class SVMAnnotator<OBJECT, ANNOTATION> extends BatchAnnotator<OBJECT, ANNOTATION> {
    public static final int POSITIVE_CLASS = 1;
    public static final int NEGATIVE_CLASS = -1;
    private FeatureExtractor<? extends FeatureVector, OBJECT> extractor;
    public HashMap<Integer, ANNOTATION> classMap = new HashMap<>();
    private svm_model model = null;
    private File saveModel = null;

    public SVMAnnotator(FeatureExtractor<? extends FeatureVector, OBJECT> featureExtractor) {
        this.extractor = null;
        this.extractor = featureExtractor;
    }

    @Override // org.openimaj.ml.training.BatchTrainer
    public void train(List<? extends Annotated<OBJECT, ANNOTATION>> list) {
        if (checkInputDataOK(list)) {
            svm_parameter defaultSVMParameters = getDefaultSVMParameters();
            this.model = svm.svm_train(getSVMProblem(list, defaultSVMParameters, this.extractor), defaultSVMParameters);
            if (this.saveModel != null) {
                try {
                    svm.svm_save_model(this.saveModel.getAbsolutePath(), this.model);
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    private boolean checkInputDataOK(List<? extends Annotated<OBJECT, ANNOTATION>> list) {
        this.classMap.clear();
        int i = 0;
        Iterator<? extends Annotated<OBJECT, ANNOTATION>> it = list.iterator();
        while (it.hasNext()) {
            Collection<ANNOTATION> annotations = it.next().getAnnotations();
            if (annotations.size() != 1) {
                throw new IllegalArgumentException("Data contained an object with more than one annotation");
            }
            ANNOTATION next = annotations.iterator().next();
            if (!this.classMap.values().contains(next)) {
                int i2 = (i * 2) - 1;
                i++;
                this.classMap.put(Integer.valueOf(i2), next);
            }
        }
        if (this.classMap.keySet().size() != 2) {
            throw new IllegalArgumentException("Data did not contain exactly 2 classes. It had " + this.classMap.keySet().size() + ". They were " + this.classMap);
        }
        return true;
    }

    @Override // org.openimaj.ml.annotation.Annotator
    public Set<ANNOTATION> getAnnotations() {
        HashSet hashSet = new HashSet();
        hashSet.addAll(this.classMap.values());
        return hashSet;
    }

    @Override // org.openimaj.ml.annotation.Annotator
    public List<ScoredAnnotation<ANNOTATION>> annotate(OBJECT object) {
        return Collections.singletonList(new ScoredAnnotation(svm.svm_predict(this.model, featureToNode((FeatureVector) this.extractor.extractFeature(object))) > 0.0d ? this.classMap.get(1) : this.classMap.get(-1), 1.0f));
    }

    public void setSaveModel(File file) {
        this.saveModel = file;
    }

    public void loadModel(File file) throws IOException {
        this.model = svm.svm_load_model(file.getAbsolutePath());
    }

    public double crossValidation(List<? extends Annotated<OBJECT, ANNOTATION>> list, int i) {
        svm_parameter defaultSVMParameters = getDefaultSVMParameters();
        return crossValidation(getSVMProblem(list, defaultSVMParameters, this.extractor), defaultSVMParameters, i);
    }

    public static double crossValidation(svm_problem svm_problemVar, svm_parameter svm_parameterVar, int i) {
        double[] dArr = new double[svm_problemVar.l];
        svm.svm_cross_validation(svm_problemVar, svm_parameterVar, i, dArr);
        int i2 = 0;
        for (int i3 = 0; i3 < svm_problemVar.l; i3++) {
            if (dArr[i3] == svm_problemVar.y[i3]) {
                i2++;
            }
        }
        double d = (100.0d * i2) / svm_problemVar.l;
        System.out.print("Cross Validation Accuracy = " + d + "%\n");
        return d;
    }

    private static svm_parameter getDefaultSVMParameters() {
        svm_parameter svm_parameterVar = new svm_parameter();
        svm_parameterVar.svm_type = 0;
        svm_parameterVar.kernel_type = 2;
        svm_parameterVar.degree = 3;
        svm_parameterVar.gamma = 0.0d;
        svm_parameterVar.coef0 = 0.0d;
        svm_parameterVar.nu = 0.5d;
        svm_parameterVar.cache_size = 100.0d;
        svm_parameterVar.C = 1.0d;
        svm_parameterVar.eps = 0.001d;
        svm_parameterVar.p = 0.1d;
        svm_parameterVar.shrinking = 1;
        svm_parameterVar.probability = 0;
        svm_parameterVar.nr_weight = 0;
        svm_parameterVar.weight_label = new int[0];
        svm_parameterVar.weight = new double[0];
        return svm_parameterVar;
    }

    /* JADX WARN: Multi-variable type inference failed */
    private svm_problem getSVMProblem(List<? extends Annotated<OBJECT, ANNOTATION>> list, svm_parameter svm_parameterVar, FeatureExtractor<? extends FeatureVector, OBJECT> featureExtractor) {
        svm_node[][] computeFeature = computeFeature(list, this.classMap.get(1));
        svm_node[][] computeFeature2 = computeFeature(list, this.classMap.get(-1));
        int length = computeFeature.length + computeFeature2.length;
        double[] dArr = new double[length];
        ArrayUtils.fill(dArr, 1.0d, 0, computeFeature.length);
        ArrayUtils.fill(dArr, -1.0d, computeFeature.length, computeFeature2.length);
        svm_node[][] svm_nodeVarArr = (svm_node[][]) ArrayUtils.concatenate(new svm_node[][]{computeFeature, computeFeature2});
        svm_problem svm_problemVar = new svm_problem();
        svm_problemVar.l = length;
        svm_problemVar.x = svm_nodeVarArr;
        svm_problemVar.y = dArr;
        svm_parameterVar.gamma = 1.0d / getMaxIndex(svm_nodeVarArr);
        return svm_problemVar;
    }

    /* JADX WARN: Type inference failed for: r0v5, types: [libsvm.svm_node[], libsvm.svm_node[][]] */
    private svm_node[][] computeFeature(List<? extends Annotated<OBJECT, ANNOTATION>> list, ANNOTATION annotation) {
        List extractFeatures = new AnnotatedListHelper(list).extractFeatures(annotation, this.extractor);
        ?? r0 = new svm_node[extractFeatures.size()];
        int i = 0;
        Iterator it = extractFeatures.iterator();
        while (it.hasNext()) {
            int i2 = i;
            i++;
            r0[i2] = featureToNode((FeatureVector) it.next());
        }
        return r0;
    }

    private static int getMaxIndex(svm_node[][] svm_nodeVarArr) {
        int i = 0;
        for (svm_node[] svm_nodeVarArr2 : svm_nodeVarArr) {
            for (svm_node svm_nodeVar : svm_nodeVarArr2) {
                i = Math.max(i, svm_nodeVar.index);
            }
        }
        return i;
    }

    private static svm_node[] featureToNode(FeatureVector featureVector) {
        double[] asDoubleVector = featureVector.asDoubleVector();
        svm_node[] svm_nodeVarArr = new svm_node[asDoubleVector.length];
        for (int i = 0; i < asDoubleVector.length; i++) {
            svm_nodeVarArr[i] = new svm_node();
            svm_nodeVarArr[i].index = i;
            svm_nodeVarArr[i].value = asDoubleVector[i];
        }
        return svm_nodeVarArr;
    }
}
