package org.openimaj.ml.annotation.linear;

import Jama.Matrix;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.openimaj.citation.annotation.Reference;
import org.openimaj.citation.annotation.ReferenceType;
import org.openimaj.feature.FeatureExtractor;
import org.openimaj.feature.FeatureVector;
import org.openimaj.math.matrix.PseudoInverse;
import org.openimaj.ml.annotation.Annotated;
import org.openimaj.ml.annotation.BatchAnnotator;
import org.openimaj.ml.annotation.ScoredAnnotation;

/**
 * Annotator that learns a dense linear transform from feature space to
 * annotation-term space.
 * <p>
 * Training stacks one extracted feature vector per item into an
 * {@code N x F} matrix and the per-item term occurrence counts into an
 * {@code N x T} matrix. The transform is computed as
 * {@code pseudoInverse(features, k) * termCounts}. Annotation multiplies a
 * single {@code 1 x F} feature row by that transform. It returns one scored
 * annotation per known term, sorted by descending score.
 *
 * @param <OBJECT>
 *            type of object being annotated
 * @param <ANNOTATION>
 *            type of annotation term
 */
@Reference(type = ReferenceType.Inproceedings, author = {"Jonathan Hare", "Paul Lewis"}, title = "Semantic Retrieval and Automatic Annotation: Linear Transformations, Correlation and Semantic Spaces", year = "2010", booktitle = "Imaging and Printing in a Web 2.0 World; and Multimedia Content Access: Algorithms and Systems IV", url = "http://eprints.soton.ac.uk/268496/", note = " Event Dates: 17-21 Jan 2010", month = "January", publisher = "SPIE", volume = "7540")
public class DenseLinearTransformAnnotator<OBJECT, ANNOTATION> extends BatchAnnotator<OBJECT, ANNOTATION> {
    /**
     * Distinct terms seen during training. Element {@code i} corresponds to
     * column {@code i} of {@link #transform}. Null until {@link #train(List)}
     * has run.
     */
    protected List<ANNOTATION> terms;

    /** Learned feature-to-term transform. Null until {@link #train(List)} has run. */
    protected Matrix transform;

    /**
     * Value passed as the second argument to
     * {@link PseudoInverse#pseudoInverse(Matrix, int)}. Presumably this is the
     * rank of the truncated pseudo-inverse; TODO confirm against PseudoInverse.
     */
    protected int k = 10;

    /** Turns objects into the feature vectors used for training and annotation. */
    private final FeatureExtractor<? extends FeatureVector, OBJECT> extractor;

    /**
     * Construct the annotator.
     *
     * @param i
     *            value stored in {@link #k} and passed to the pseudo-inverse
     * @param featureExtractor
     *            extractor used to turn objects into feature vectors
     */
    public DenseLinearTransformAnnotator(int i, FeatureExtractor<? extends FeatureVector, OBJECT> featureExtractor) {
        this.extractor = featureExtractor;
        this.k = i;
    }

    /**
     * Learn the transform from a batch of annotated objects.
     *
     * @param list
     *            the training data; must be non-empty, and every object must
     *            yield a feature vector of the same length
     * @throws IllegalArgumentException
     *             if {@code list} is null or empty, or the feature vectors
     *             have inconsistent lengths
     */
    @Override
    public void train(List<? extends Annotated<OBJECT, ANNOTATION>> list) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("Cannot train on a null or empty list of annotated objects");
        }

        final Set<ANNOTATION> termSet = new HashSet<ANNOTATION>();
        for (final Annotated<OBJECT, ANNOTATION> item : list) {
            termSet.addAll(item.getAnnotations());
        }
        this.terms = new ArrayList<ANNOTATION>(termSet);

        // Precompute term -> column so each annotation lookup is O(1)
        // rather than a linear List.indexOf scan.
        final Map<ANNOTATION, Integer> termIndex = new HashMap<ANNOTATION, Integer>();
        for (int t = 0; t < this.terms.size(); t++) {
            termIndex.put(this.terms.get(t), t);
        }

        final int numItems = list.size();
        final int numTerms = this.terms.size();

        // The first item's feature vector fixes the feature dimensionality.
        // Iterate once rather than calling list.get(i), which is O(i) on linked lists.
        final Iterator<? extends Annotated<OBJECT, ANNOTATION>> it = list.iterator();
        final Annotated<OBJECT, ANNOTATION> first = it.next();
        final double[] firstFeature = extractFeatureArray(first.getObject());

        final Matrix features = new Matrix(numItems, firstFeature.length);
        final Matrix termCounts = new Matrix(numItems, numTerms);

        addRow(features, termCounts, 0, firstFeature, first.getAnnotations(), termIndex);
        int row = 1;
        while (it.hasNext()) {
            final Annotated<OBJECT, ANNOTATION> item = it.next();
            addRow(features, termCounts, row, extractFeatureArray(item.getObject()), item.getAnnotations(), termIndex);
            row++;
        }

        this.transform = PseudoInverse.pseudoInverse(features, this.k).times(termCounts);
    }

    /** Extract the feature of {@code object} as a primitive double array. */
    private double[] extractFeatureArray(OBJECT object) {
        return this.extractor.extractFeature(object).asDoubleVector();
    }

    /**
     * Fill row {@code row} of both training matrices.
     *
     * @param features
     *            feature matrix; its column count is the expected feature length
     * @param termCounts
     *            term-count matrix; one column per training term
     * @param row
     *            row to fill
     * @param featureVector
     *            the item's feature vector
     * @param annotations
     *            the item's annotations; each occurrence adds 1 to its term column
     * @param termIndex
     *            mapping from term to its column in {@code termCounts}
     * @throws IllegalArgumentException
     *             if the feature vector length does not match the matrix width
     */
    private void addRow(Matrix features, Matrix termCounts, int row, double[] featureVector,
            Collection<ANNOTATION> annotations, Map<ANNOTATION, Integer> termIndex)
    {
        final int dims = features.getColumnDimension();
        if (featureVector.length != dims) {
            throw new IllegalArgumentException("Feature vector for training item " + row + " has length "
                    + featureVector.length + " but " + dims + " was expected");
        }
        // Jama's getArray() exposes the backing array, so this writes into the matrix.
        System.arraycopy(featureVector, 0, features.getArray()[row], 0, dims);

        final double[] countRow = termCounts.getArray()[row];
        for (final ANNOTATION annotation : annotations) {
            countRow[termIndex.get(annotation)] += 1.0;
        }
    }

    /**
     * Score every known term for {@code object}.
     *
     * @param object
     *            the object to annotate
     * @return one scored annotation per training term, ordered by descending
     *         confidence
     * @throws IllegalStateException
     *             if called before {@link #train(List)}
     */
    @Override
    public List<ScoredAnnotation<ANNOTATION>> annotate(OBJECT object) {
        if (this.transform == null || this.terms == null) {
            throw new IllegalStateException("The annotator must be trained before annotate() is called");
        }

        final double[] feature = extractFeatureArray(object);
        final Matrix scores = new Matrix(new double[][] { feature }).times(this.transform);

        final List<ScoredAnnotation<ANNOTATION>> result =
                new ArrayList<ScoredAnnotation<ANNOTATION>>(this.terms.size());
        for (int i = 0; i < this.terms.size(); i++) {
            result.add(new ScoredAnnotation<ANNOTATION>(this.terms.get(i), (float) scores.get(0, i)));
        }

        // Descending by confidence. Float.compare returns 0 for ties, which
        // keeps the Comparator contract; the previous "? 1 : -1" never returned 0.
        Collections.sort(result, new Comparator<ScoredAnnotation<ANNOTATION>>() {
            @Override
            public int compare(ScoredAnnotation<ANNOTATION> a, ScoredAnnotation<ANNOTATION> b) {
                return Float.compare(b.confidence, a.confidence);
            }
        });
        return result;
    }

    /**
     * @return a new set containing the terms learned during training, or an
     *         empty set if the annotator has not been trained
     */
    @Override
    public Set<ANNOTATION> getAnnotations() {
        if (this.terms == null) {
            return new HashSet<ANNOTATION>();
        }
        return new HashSet<ANNOTATION>(this.terms);
    }
}
