package org.openimaj.demos.sandbox.ml.linear.learner.stream;

import com.google.common.collect.BiMap;
import gov.sandia.cognition.math.matrix.Matrix;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.IncrementalBilinearSparseOnlineLearner;
import org.openimaj.util.data.Context;
import org.openimaj.util.pair.Pair;

/* loaded from: input_file:org/openimaj/demos/sandbox/ml/linear/learner/stream/ModelStats.class */
/**
 * Statistics about an {@link IncrementalBilinearSparseOnlineLearner}.
 *
 * <p>The main constructor feeds one stream item (read from a {@link Context}) to the learner,
 * scores it with a {@link BilinearEvaluator}, and records the following:
 * <ul>
 *   <li>the score;</li>
 *   <li>a {@link SortedImportantWords} per dependant value (task);</li>
 *   <li>the per-task min/max of the W and U factor columns;</li>
 *   <li>the learner's bias;</li>
 *   <li>the target and predicted output matrices.</li>
 * </ul>
 *
 * <p>The no-arg constructor creates an empty instance with no learner.
 */
public class ModelStats {
    /** Number of words requested from {@link SortedImportantWords} for each task. */
    private static final int IMPORTANT_WORDS_PER_TASK = 10;

    /** Value returned by {@link BilinearEvaluator#evaluate} for the latest item (0 when empty). */
    public double score;
    /** {@link SortedImportantWords} for each dependant value (task), keyed by task name. */
    public Map<String, SortedImportantWords> importantWords;
    /** The learner these statistics describe; {@code null} for an empty instance. */
    public IncrementalBilinearSparseOnlineLearner learner;
    /** Per task: (min, max) of that task's column of the bilinear learner's W matrix. */
    public Map<String, Pair<Double>> taskWordMinMax;
    /** Second element of the learner's matrix pair for the latest item (the target Y). */
    public Matrix correctY;
    /** Prediction of the bilinear learner for the first element (X) of the latest item. */
    public Matrix estimatedY;
    /** Bias matrix of the underlying bilinear learner. */
    public Matrix bias;
    /** Per task: (min, max) of that task's column of the bilinear learner's U matrix. */
    public Map<String, Pair<Double>> userMinMax;

    /** Creates an empty instance with no learner and empty maps. */
    public ModelStats() {
        this.score = 0.0d;
        this.learner = null;
        this.importantWords = new HashMap<String, SortedImportantWords>();
        this.taskWordMinMax = new HashMap<String, Pair<Double>>();
        // Initialised like the other maps so readers of this public field never see null.
        this.userMinMax = new HashMap<String, Pair<Double>>();
    }

    /**
     * Updates the learner with the stream item held in {@code context}, then records statistics
     * about the resulting model.
     *
     * @param bilinearEvaluator evaluator used to score the item; it is pointed at the learner's
     *     underlying {@link BilinearSparseOnlineLearner} via {@code setLearner}
     * @param incrementalBilinearSparseOnlineLearner the learner to update and inspect
     * @param context stream item; its "bagofwords" and "averageticks" entries are read
     */
    public ModelStats(
            BilinearEvaluator bilinearEvaluator,
            IncrementalBilinearSparseOnlineLearner incrementalBilinearSparseOnlineLearner,
            Context context) {
        // NOTE(review): the element types of these context entries are not visible here, so the
        // maps are kept raw and handed to the learner unchanged.
        Map bagOfWords = (Map) context.getTyped("bagofwords");
        Map averageTicks = (Map) context.getTyped("averageticks");

        this.learner = incrementalBilinearSparseOnlineLearner;
        this.learner.updateUserValues(bagOfWords, averageTicks);
        BilinearSparseOnlineLearner bilinearLearner = this.learner.getBilinearLearner();
        bilinearEvaluator.setLearner(bilinearLearner);

        // The evaluator consumes a list of (X, Y) pairs; this item is the only one.
        Pair<Matrix> xy = this.learner.asMatrixPair(bagOfWords, averageTicks);
        List<Pair<Matrix>> evaluationData = new ArrayList<Pair<Matrix>>();
        evaluationData.add(xy);
        this.score = bilinearEvaluator.evaluate(evaluationData);

        this.importantWords = importantWords();
        this.taskWordMinMax = minMaxWords();
        this.userMinMax = minMaxUsers();
        this.correctY = xy.secondObject();
        this.bias = bilinearLearner.getBias();
        this.estimatedY = bilinearLearner.predict(xy.firstObject());
    }

    /** Per task: (min, max) of the corresponding column of U; empty when there is no learner. */
    private Map<String, Pair<Double>> minMaxUsers() {
        if (this.learner == null) {
            return new HashMap<String, Pair<Double>>();
        }
        return columnMinMaxPerTask(this.learner.getBilinearLearner().getU());
    }

    /** Per task: (min, max) of the corresponding column of W; empty when there is no learner. */
    private Map<String, Pair<Double>> minMaxWords() {
        if (this.learner == null) {
            return new HashMap<String, Pair<Double>>();
        }
        return columnMinMaxPerTask(this.learner.getBilinearLearner().getW());
    }

    /**
     * For every dependant value (task) of the learner, computes the (min, max) of the column of
     * {@code factor} at that task's index. Requires {@code this.learner} to be non-null.
     */
    private Map<String, Pair<Double>> columnMinMaxPerTask(Matrix factor) {
        Map<String, Pair<Double>> result = new HashMap<String, Pair<Double>>();
        BiMap<String, Integer> dependantValues = this.learner.getDependantValues();
        for (Map.Entry<String, Integer> entry : dependantValues.entrySet()) {
            int column = entry.getValue();
            double min = CFMatrixUtils.min(factor.getColumn(column));
            double max = CFMatrixUtils.max(factor.getColumn(column));
            result.put(entry.getKey(), new Pair<Double>(min, max));
        }
        return result;
    }

    /** Builds a {@link SortedImportantWords} per task; empty when there is no learner. */
    private Map<String, SortedImportantWords> importantWords() {
        Map<String, SortedImportantWords> result = new HashMap<String, SortedImportantWords>();
        if (this.learner == null) {
            return result;
        }
        BiMap<String, Integer> dependantValues = this.learner.getDependantValues();
        BilinearSparseOnlineLearner bilinearLearner = this.learner.getBilinearLearner();
        for (String task : dependantValues.keySet()) {
            result.put(task, new SortedImportantWords(
                    task, this.learner, bilinearLearner, IMPORTANT_WORDS_PER_TASK));
        }
        return result;
    }

    /**
     * Prints the recorded statistics to standard output.
     *
     * <p>For an empty instance (no learner), only "No loss!" is printed.
     */
    public void printSummary() {
        if (this.learner == null) {
            System.out.println("No loss!");
            return;
        }
        System.out.println("Loss: " + this.score);
        System.out.println("Important words: ");
        BilinearSparseOnlineLearner bilinearLearner = this.learner.getBilinearLearner();
        BiMap<Integer, String> indexToWord = this.learner.getVocabulary().inverse();
        Matrix w = bilinearLearner.getW();
        for (Map.Entry<String, SortedImportantWords> entry : this.importantWords.entrySet()) {
            String task = entry.getKey();
            SortedImportantWords sorted = entry.getValue();
            Pair<Double> wordMinMax = this.taskWordMinMax.get(task);
            for (int i : sorted.indexes) {
                System.out.println("Word: " + indexToWord.get(i) + " index " + i);
                // The word's row of W holds its weight for every task.
                System.out.println(w.getRow(i));
            }
            System.out.printf("... %s (%1.4f->%1.4f) %s%n",
                    task, wordMinMax.firstObject(), wordMinMax.secondObject(), sorted);
        }
        System.out.println("User importance: ");
        for (Map.Entry<String, Pair<Double>> entry : this.userMinMax.entrySet()) {
            Pair<Double> userRange = entry.getValue();
            System.out.printf("... %s (%1.4f->%1.4f)%n",
                    entry.getKey(), userRange.firstObject(), userRange.secondObject());
        }
        System.out.println("Model Bias: \n" + this.bias);
        System.out.println("Correct Y: \n" + this.correctY);
        System.out.println("Estimated Y: \n" + this.estimatedY);
    }
}
