package org.openimaj.ml.linear.experiments.sinabill;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.openimaj.io.IOUtils;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator;
import org.openimaj.ml.linear.evaluation.RootMeanSumLossEvaluator;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.init.SingleValueInitStrat;
import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy;
import org.openimaj.util.pair.Pair;

/* loaded from: input_file:org/openimaj/ml/linear/experiments/sinabill/BillAustrianDampeningExperiments.class */
/**
 * Experiment that sweeps the {@code DAMPENING} learner parameter over a fixed
 * range and, for each value, trains a fresh {@link BilinearSparseOnlineLearner}
 * on the training split of one fold, evaluating against that fold's test split
 * and writing the learner to disk after every training item.
 */
public class BillAustrianDampeningExperiments extends BilinearExperiment {
    /** Fold used for both the test split and the training split. */
    private static final int FOLD = 5;

    /** First dampening value tried (inclusive). */
    private static final double DAMPENING_MIN = 0.0d;

    /** Upper bound of the dampening sweep (exclusive). */
    private static final double DAMPENING_MAX = 0.02d;

    /** Distance between two consecutive dampening values. */
    private static final double DAMPENING_INCR = 1.0E-4d;

    /**
     * Command-line entry point; runs the experiment once.
     *
     * @param strArr command-line arguments (unused)
     * @throws Exception if the experiment fails
     */
    public static void main(String[] strArr) throws Exception {
        new BillAustrianDampeningExperiments().performExperiment();
    }

    /**
     * Runs the dampening sweep.
     *
     * <p>The test split of fold {@value #FOLD} is read once up front; the
     * generator is then switched back to the training split at the start of
     * every sweep step.
     *
     * @throws IOException if a learner cannot be written to disk
     */
    @Override
    public void performExperiment() throws IOException {
        BilinearLearnerParameters params = createBaseParameters();
        // NOTE(review): the meaning of the arguments 98 and true is defined by
        // BillMatlabFileDataGenerator and is not visible from this class.
        BillMatlabFileDataGenerator generator =
                new BillMatlabFileDataGenerator(new File(MATLAB_DATA()), 98, true);
        prepareExperimentLog(params);

        this.logger.debug("Starting dampening experiments");
        this.logger.debug("Fold: " + FOLD);
        generator.setFold(FOLD, BillMatlabFileDataGenerator.Mode.TEST);
        List<Pair<Matrix>> testPairs = drain(generator);

        this.logger.debug(String.format(
                "Beggining dampening experiments: min=%2.5f,max=%2.5f,incr=%2.5f",
                DAMPENING_MIN, DAMPENING_MAX, DAMPENING_INCR));

        // Drive the sweep with an integer counter and derive each value by
        // multiplication: repeatedly adding the increment to a double
        // accumulates rounding error, which makes the number of iterations
        // (and the exact values tried) depend on floating-point drift.
        int stepCount = (int) Math.round((DAMPENING_MAX - DAMPENING_MIN) / DAMPENING_INCR);
        for (int step = 0; step < stepCount; step++) {
            double dampening = DAMPENING_MIN + step * DAMPENING_INCR;
            params.put(BilinearLearnerParameters.DAMPENING, dampening);

            BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(params);
            learner.reinitParams();
            this.logger.debug("Dampening is now: " + dampening);
            this.logger.debug("...training");

            generator.setFold(FOLD, BillMatlabFileDataGenerator.Mode.TRAINING);
            trainOnCurrentFold(learner, generator, testPairs, dampening);
        }
    }

    /** Builds the learner parameters shared by every step of the sweep. */
    private static BilinearLearnerParameters createBaseParameters() {
        BilinearLearnerParameters params = new BilinearLearnerParameters();
        params.put(BilinearLearnerParameters.ETA0_U, 0.02d);
        params.put(BilinearLearnerParameters.ETA0_W, 0.02d);
        params.put(BilinearLearnerParameters.LAMBDA, 0.001d);
        params.put(BilinearLearnerParameters.BICONVEX_TOL, 0.01d);
        params.put(BilinearLearnerParameters.BICONVEX_MAXITER, 10);
        params.put(BilinearLearnerParameters.BIAS, true);
        params.put(BilinearLearnerParameters.ETA0_BIAS, 0.5d);
        params.put(BilinearLearnerParameters.WINITSTRAT, new SingleValueInitStrat(0.1d));
        params.put(BilinearLearnerParameters.UINITSTRAT, new SparseZerosInitStrategy());
        return params;
    }

    /** Collects every pair the generator yields until it returns {@code null}. */
    private static List<Pair<Matrix>> drain(BillMatlabFileDataGenerator generator) {
        List<Pair<Matrix>> pairs = new ArrayList<Pair<Matrix>>();
        Pair<Matrix> next = generator.mo9generate();
        while (next != null) {
            pairs.add(next);
            next = generator.mo9generate();
        }
        return pairs;
    }

    /**
     * Feeds every pair the generator yields to the learner and, after each
     * one, evaluates the learner on the test pairs, saves it and logs
     * diagnostics.
     */
    private void trainOnCurrentFold(BilinearSparseOnlineLearner learner,
            BillMatlabFileDataGenerator generator, List<Pair<Matrix>> testPairs,
            double dampening) throws IOException {
        int seen = 0;
        Pair<Matrix> item = generator.mo9generate();
        while (item != null) {
            // The "trying" message is numbered from 0, whereas the save message
            // and the file name below are numbered from 1 (kept as-is so output
            // file names do not change).
            this.logger.debug("...trying item " + seen);
            seen++;
            learner.process(item.firstObject(), item.secondObject());

            Matrix u = learner.getU();
            Matrix w = learner.getW();
            Matrix bias = MatrixFactory.getDenseDefault().copyMatrix(learner.getBias());

            RootMeanSumLossEvaluator evaluator = new RootMeanSumLossEvaluator();
            evaluator.setLearner(learner);
            double loss = evaluator.evaluate(testPairs);

            this.logger.debug(String.format("Saving learner, Fold %d, Item %d", FOLD, seen));
            File learnerFile = new File(FOLD_ROOT(FOLD),
                    String.format("learner_%d_dampening=%2.5f", seen, dampening));
            IOUtils.writeBinary(learnerFile, learner);

            this.logger.debug("W row sparcity: " + CFMatrixUtils.rowSparsity(w));
            this.logger.debug(String.format("W range: %2.5f -> %2.5f",
                    CFMatrixUtils.min(w), CFMatrixUtils.max(w)));
            this.logger.debug("U row sparcity: " + CFMatrixUtils.rowSparsity(u));
            this.logger.debug(String.format("U range: %2.5f -> %2.5f",
                    CFMatrixUtils.min(u), CFMatrixUtils.max(u)));
            if (((Boolean) learner.getParams().getTyped(BilinearLearnerParameters.BIAS)).booleanValue()) {
                this.logger.debug("Bias: " + CFMatrixUtils.diag(bias));
            }
            this.logger.debug(String.format("... loss: %f", loss));

            item = generator.mo9generate();
        }
    }
}
