package org.openimaj.ml.linear.learner;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
import java.util.HashMap;
import java.util.Map;
import org.openimaj.util.pair.IndependentPair;
import org.openimaj.util.pair.Pair;

/* loaded from: input_file:org/openimaj/ml/linear/learner/IncrementalBilinearSparseOnlineLearner.class */
/**
 * Adapter around a {@link BilinearSparseOnlineLearner} that accepts string-keyed maps instead of
 * fixed-size matrices.
 *
 * <p>Users, words and dependent values are each assigned an integer index in order of first
 * appearance. For every training batch, the underlying learner is told how many new users
 * ({@code addU}) and new words ({@code addW}) the batch introduced.
 *
 * <p>Matrix layout:
 * <ul>
 * <li>An input is a map {@code user -> (word -> value)}. It is converted to a sparse matrix of
 * shape {@code vocabulary.size() x users.size()}, with words as rows and users as columns.</li>
 * <li>A target is a map {@code valueName -> value}. It is converted to a sparse matrix of shape
 * {@code 1 x values.size()}.</li>
 * </ul>
 */
public class IncrementalBilinearSparseOnlineLearner implements OnlineLearner<Map<String, Map<String, Double>>, Map<String, Double>> {
    /** Word to row index in the input matrix, assigned in order of first appearance. */
    private BiMap<String, Integer> vocabulary;
    /** User to column index in the input matrix, assigned in order of first appearance. */
    private BiMap<String, Integer> users;
    /** Dependent-value name to column index in the target matrix, assigned in order of first appearance. */
    private BiMap<String, Integer> values;
    /** The matrix-based learner that this class feeds. */
    private BilinearSparseOnlineLearner bilinearLearner;
    /** Parameters used to (re)build {@link #bilinearLearner}. */
    private BilinearLearnerParameters params;

    /**
     * Parameter type used by the no-argument constructor. It adds nothing to
     * {@link BilinearLearnerParameters}.
     */
    static class IncrementalBilinearSparseOnlineLearnerParams extends BilinearLearnerParameters {
        private static final long serialVersionUID = -1847045895118918210L;

        IncrementalBilinearSparseOnlineLearnerParams() {
        }
    }

    /** Creates a learner using a fresh {@link IncrementalBilinearSparseOnlineLearnerParams}. */
    public IncrementalBilinearSparseOnlineLearner() {
        init(new IncrementalBilinearSparseOnlineLearnerParams());
    }

    /**
     * Creates a learner using the given parameters.
     *
     * @param bilinearLearnerParameters parameters handed to the underlying learner
     */
    public IncrementalBilinearSparseOnlineLearner(BilinearLearnerParameters bilinearLearnerParameters) {
        init(bilinearLearnerParameters);
    }

    /**
     * Discards all indices and the underlying learner, then rebuilds them from the current
     * parameters.
     */
    public void reinitParams() {
        init(this.params);
    }

    /** Resets the three index maps and builds a new underlying learner from the given parameters. */
    private void init(BilinearLearnerParameters bilinearLearnerParameters) {
        this.vocabulary = HashBiMap.create();
        this.users = HashBiMap.create();
        this.values = HashBiMap.create();
        this.params = bilinearLearnerParameters;
        this.bilinearLearner = new BilinearSparseOnlineLearner(bilinearLearnerParameters);
    }

    /**
     * Returns the parameters this learner was last (re)initialised with.
     *
     * @return the current parameters
     */
    public BilinearLearnerParameters getParams() {
        return this.params;
    }

    /**
     * Registers any unseen users, words and dependent values, then trains the underlying learner
     * on the converted matrices.
     *
     * @param userWords user to (word to value) input
     * @param dependentValues dependent-value name to target value
     */
    @Override // org.openimaj.ml.linear.learner.OnlineLearner
    public void process(Map<String, Map<String, Double>> userWords, Map<String, Double> dependentValues) {
        // Index every key first so that both conversions below can resolve them.
        updateUserValues(userWords, dependentValues);
        final Matrix y = constructYMatrix(dependentValues);
        final Matrix x = constructXMatrix(userWords);
        this.bilinearLearner.process(x, y);
    }

    /**
     * Assigns indices to any users, words and dependent values not seen before.
     *
     * <p>This method tells the underlying learner about new users and words. It does not grow
     * the learner for new dependent values.
     *
     * @param userWords user to (word to value) input
     * @param dependentValues dependent-value name to target value
     */
    public void updateUserValues(Map<String, Map<String, Double>> userWords, Map<String, Double> dependentValues) {
        updateUserWords(userWords);
        updateValues(dependentValues);
    }

    /**
     * Indexes any unseen dependent-value names.
     *
     * <p>NOTE(review): the underlying learner is not notified of new values here, unlike users
     * and words. Confirm that BilinearSparseOnlineLearner copes with a target matrix that gains
     * columns after the first batch.
     */
    private void updateValues(Map<String, Double> dependentValues) {
        for (String name : dependentValues.keySet()) {
            addIfAbsent(this.values, name);
        }
    }

    /**
     * Assigns {@code key} the next free index ({@code index.size()}) if it is not already present.
     *
     * @return {@code true} if the key was newly added
     */
    private static boolean addIfAbsent(BiMap<String, Integer> index, String key) {
        if (index.containsKey(key)) {
            return false;
        }
        index.put(key, index.size());
        return true;
    }

    /**
     * Builds the {@code 1 x values.size()} target matrix.
     *
     * <p>Every key must already be indexed. This holds inside {@link #process}, because it calls
     * {@link #updateUserValues} first.
     */
    private Matrix constructYMatrix(Map<String, Double> dependentValues) {
        final SparseMatrix y = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(1, this.values.size());
        for (Map.Entry<String, Double> entry : dependentValues.entrySet()) {
            final int column = this.values.get(entry.getKey());
            y.setElement(0, column, entry.getValue());
        }
        return y;
    }

    /** Converts a {@code 1 x values.size()} prediction row back into a name-to-value map. */
    private Map<String, Double> constructYMap(Matrix matrix) {
        final Map<String, Double> result = new HashMap<>();
        for (Map.Entry<String, Integer> entry : this.values.entrySet()) {
            result.put(entry.getKey(), matrix.getElement(0, entry.getValue()));
        }
        return result;
    }

    /**
     * Builds the {@code vocabulary.size() x users.size()} input matrix.
     *
     * <p>Users or words that have never been indexed have no row or column in the matrix, so they
     * are skipped. Previously they caused a NullPointerException. This can only happen on the
     * {@link #predict} path, because {@link #process} indexes everything first.
     */
    private Matrix constructXMatrix(Map<String, Map<String, Double>> userWords) {
        final SparseMatrix x = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(this.vocabulary.size(), this.users.size());
        for (Map.Entry<String, Map<String, Double>> userEntry : userWords.entrySet()) {
            final Integer userIndex = this.users.get(userEntry.getKey());
            if (userIndex == null) {
                continue;
            }
            for (Map.Entry<String, Double> wordEntry : userEntry.getValue().entrySet()) {
                final Integer wordIndex = this.vocabulary.get(wordEntry.getKey());
                if (wordIndex != null) {
                    x.setElement(wordIndex, userIndex, wordEntry.getValue());
                }
            }
        }
        return x;
    }

    /**
     * Indexes unseen users and words, then tells the underlying learner how many of each were
     * added in this batch.
     */
    private void updateUserWords(Map<String, Map<String, Double>> userWords) {
        int newUsers = 0;
        int newWords = 0;
        for (Map.Entry<String, Map<String, Double>> entry : userWords.entrySet()) {
            if (addIfAbsent(this.users, entry.getKey())) {
                newUsers++;
            }
            newWords += updateWords(entry.getValue());
        }
        this.bilinearLearner.addU(newUsers);
        this.bilinearLearner.addW(newWords);
    }

    /**
     * Indexes any unseen words in one user's word map.
     *
     * @return the number of words newly added to the vocabulary
     */
    private int updateWords(Map<String, Double> words) {
        int added = 0;
        for (String word : words.keySet()) {
            if (addIfAbsent(this.vocabulary, word)) {
                added++;
            }
        }
        return added;
    }

    /**
     * Returns a clone of the underlying learner with U and W padded with zero rows.
     *
     * <p>The existing U is copied into the top-left of a {@code nRowsU x U.columns} sparse matrix.
     * W is padded the same way into {@code nRowsW} rows. The requested sizes must be at least the
     * current row counts. Presumably U rows are users and W rows are words, given the
     * {@code addU}/{@code addW} calls; verify against BilinearSparseOnlineLearner.
     *
     * @param nRowsU number of rows for the padded U
     * @param nRowsW number of rows for the padded W
     * @return the padded clone; this learner is not modified
     */
    public BilinearSparseOnlineLearner getBilinearLearner(int nRowsU, int nRowsW) {
        final BilinearSparseOnlineLearner copy = this.bilinearLearner.m13clone();
        final Matrix u = copy.getU();
        final Matrix w = copy.getW();
        final SparseMatrix paddedU = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nRowsU, u.getNumColumns());
        final SparseMatrix paddedW = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nRowsW, w.getNumColumns());
        paddedU.setSubMatrix(0, 0, u);
        paddedW.setSubMatrix(0, 0, w);
        copy.setU(paddedU);
        copy.setW(paddedW);
        return copy;
    }

    /**
     * Returns a clone of the underlying learner.
     *
     * @return the clone
     */
    public BilinearSparseOnlineLearner getBilinearLearner() {
        return this.bilinearLearner.m13clone();
    }

    /**
     * Converts an (input, target) pair into (X, Y) matrices of explicitly given sizes.
     *
     * <p>Unlike {@link #predict}, every user, word and value key must already be indexed. Unknown
     * keys still fail with a NullPointerException. Indices must also fit within the given sizes.
     *
     * @param independentPair input map (first) and target map (second)
     * @param i number of rows of X (words)
     * @param i2 number of columns of X (users)
     * @param i3 number of columns of Y (dependent values)
     * @return pair of (X, Y)
     */
    public Pair<Matrix> asMatrixPair(IndependentPair<Map<String, Map<String, Double>>, Map<String, Double>> independentPair, int i, int i2, int i3) {
        final SparseMatrix y = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(1, i3);
        final SparseMatrix x = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(i, i2);
        final Map<String, Double> dependentValues = independentPair.secondObject();
        final Map<String, Map<String, Double>> userWords = independentPair.firstObject();
        for (Map.Entry<String, Double> entry : dependentValues.entrySet()) {
            final int column = this.values.get(entry.getKey());
            y.setElement(0, column, entry.getValue());
        }
        for (Map.Entry<String, Map<String, Double>> userEntry : userWords.entrySet()) {
            final int userIndex = this.users.get(userEntry.getKey());
            for (Map.Entry<String, Double> wordEntry : userEntry.getValue().entrySet()) {
                final int wordIndex = this.vocabulary.get(wordEntry.getKey());
                x.setElement(wordIndex, userIndex, wordEntry.getValue());
            }
        }
        return new Pair<>(x, y);
    }

    /**
     * Predicts dependent values for the given input.
     *
     * <p>Users and words never seen during training are ignored rather than causing an error.
     *
     * @param userWords user to (word to value) input
     * @return a map from every known dependent-value name to its predicted value
     */
    @Override // org.openimaj.ml.linear.learner.OnlineLearner
    public Map<String, Double> predict(Map<String, Map<String, Double>> userWords) {
        return constructYMap(this.bilinearLearner.predict(constructXMatrix(userWords)));
    }

    /**
     * Returns the live (not copied) word-to-row index map.
     *
     * @return the vocabulary index
     */
    public BiMap<String, Integer> getVocabulary() {
        return this.vocabulary;
    }

    /**
     * Converts a pair using the current vocabulary, user and value counts as matrix sizes.
     *
     * @param independentPair input map (first) and target map (second)
     * @return pair of (X, Y)
     */
    public Pair<Matrix> asMatrixPair(IndependentPair<Map<String, Map<String, Double>>, Map<String, Double>> independentPair) {
        return asMatrixPair(independentPair, this.vocabulary.size(), this.users.size(), this.values.size());
    }

    /**
     * Converts an input map and target map using the current vocabulary, user and value counts
     * as matrix sizes.
     *
     * @param userWords user to (word to value) input
     * @param dependentValues dependent-value name to target value
     * @return pair of (X, Y)
     */
    public Pair<Matrix> asMatrixPair(Map<String, Map<String, Double>> userWords, Map<String, Double> dependentValues) {
        return asMatrixPair(IndependentPair.pair(userWords, dependentValues), this.vocabulary.size(), this.users.size(), this.values.size());
    }

    /**
     * Returns the live (not copied) dependent-value-name-to-column index map.
     *
     * @return the dependent value index
     */
    public BiMap<String, Integer> getDependantValues() {
        return this.values;
    }

    /**
     * Returns the live (not copied) user-to-column index map.
     *
     * @return the user index
     */
    public BiMap<String, Integer> getUsers() {
        return this.users;
    }
}
