package org.openimaj.ml.neuralnet;

import Jama.Matrix;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.procedure.TIntObjectProcedure;
import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.algorithm.IterativeAlgorithmListener;
import gov.sandia.cognition.io.CSVUtility;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerLiuStorey;
import gov.sandia.cognition.learning.algorithm.regression.ParameterDifferentiableCostMinimizer;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.scalar.AtanFunction;
import gov.sandia.cognition.learning.function.vector.ThreeLayerFeedforwardNeuralNetwork;
import gov.sandia.cognition.math.DifferentiableUnivariateScalarFunction;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorEntry;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.mtj.DenseVectorFactoryMTJ;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import org.openimaj.util.pair.IndependentPair;

/* loaded from: input_file:org/openimaj/ml/neuralnet/HandWritingNeuralNetSANDIA.class */
public class HandWritingNeuralNetSANDIA implements IterativeAlgorithmListener {
    public static final String INPUT_LOCATION = "/org/openimaj/ml/handwriting/inputs.csv";
    public static final String OUTPUT_LOCATION = "/org/openimaj/ml/handwriting/outputs.csv";
    private Matrix xVals;
    private Matrix yVals;
    private ArrayList<InputOutputPair<Vector, Vector>> dataCollection;
    private TIntIntHashMap examples;
    private TIntObjectHashMap<List<IndependentPair<Vector, Vector>>> tests;
    private GradientDescendable neuralNet;
    private int maxExamples = 400;
    private int maxTests = 10;
    private int nHiddenLayer = 20;
    private int totalTests = 0;

    public HandWritingNeuralNetSANDIA() throws IOException {
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(HandWritingNeuralNetSANDIA.class.getResourceAsStream(INPUT_LOCATION)));
        BufferedReader bufferedReader2 = new BufferedReader(new InputStreamReader(HandWritingNeuralNetSANDIA.class.getResourceAsStream(OUTPUT_LOCATION)));
        this.xVals = fromCSV(bufferedReader, 5000);
        this.yVals = fromCSV(bufferedReader2, 5000);
        this.examples = new TIntIntHashMap();
        this.tests = new TIntObjectHashMap<>();
        prepareDataCollection();
        learnNeuralNet();
        testNeuralNet();
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [double[], double[][]] */
    private void testNeuralNet() {
        final ?? r0 = new double[this.totalTests];
        final int[] iArr = new int[this.totalTests];
        this.tests.forEachEntry(new TIntObjectProcedure<List<IndependentPair<Vector, Vector>>>() { // from class: org.openimaj.ml.neuralnet.HandWritingNeuralNetSANDIA.1
            int done = 0;
            DenseVectorFactoryMTJ fact = new DenseVectorFactoryMTJ();

            public boolean execute(int i, List<IndependentPair<Vector, Vector>> list) {
                for (IndependentPair<Vector, Vector> independentPair : list) {
                    int i2 = 0;
                    double d = 0.0d;
                    for (VectorEntry vectorEntry : (Vector) HandWritingNeuralNetSANDIA.this.neuralNet.evaluate(independentPair.firstObject())) {
                        if (d < vectorEntry.getValue()) {
                            d = vectorEntry.getValue();
                            i2 = vectorEntry.getIndex();
                        }
                    }
                    r0[this.done] = this.fact.copyVector((Vector) independentPair.firstObject()).getArray();
                    iArr[this.done] = i2;
                    this.done++;
                }
                return true;
            }
        });
        new HandWritingInputDisplay(r0, iArr);
    }

    private void prepareDataCollection() {
        this.dataCollection = new ArrayList<>();
        double[][] array = this.xVals.getArray();
        double[][] array2 = this.yVals.getArray();
        for (int i = 0; i < array.length; i++) {
            Vector copyArray = VectorFactory.getDefault().copyArray(array[i]);
            double[] dArr = new double[10];
            int i2 = (int) (array2[i][0] % 10.0d);
            int adjustOrPutValue = this.examples.adjustOrPutValue(i2, 1, 1);
            dArr[i2] = 1.0d;
            Vector copyValues = VectorFactory.getDefault().copyValues(dArr);
            if (this.maxExamples == -1 || adjustOrPutValue <= this.maxExamples) {
                this.dataCollection.add(DefaultInputOutputPair.create(copyArray, copyValues));
            } else if (adjustOrPutValue <= this.maxTests + this.maxExamples) {
                List list = (List) this.tests.get(i2);
                if (list == null) {
                    TIntObjectHashMap<List<IndependentPair<Vector, Vector>>> tIntObjectHashMap = this.tests;
                    ArrayList arrayList = new ArrayList();
                    list = arrayList;
                    tIntObjectHashMap.put(i2, arrayList);
                }
                list.add(IndependentPair.pair(copyArray, copyValues));
                this.totalTests++;
            }
        }
    }

    private void learnNeuralNet() {
        toArrayList(new Integer[]{Integer.valueOf(this.xVals.getColumnDimension()), Integer.valueOf(this.nHiddenLayer), 10});
        toArrayList(new DifferentiableUnivariateScalarFunction[]{new AtanFunction(), new AtanFunction()});
        ThreeLayerFeedforwardNeuralNetwork threeLayerFeedforwardNeuralNetwork = new ThreeLayerFeedforwardNeuralNetwork(this.xVals.getColumnDimension(), this.nHiddenLayer, 10);
        ParameterDifferentiableCostMinimizer parameterDifferentiableCostMinimizer = new ParameterDifferentiableCostMinimizer(new FunctionMinimizerLiuStorey());
        parameterDifferentiableCostMinimizer.setObjectToOptimize(threeLayerFeedforwardNeuralNetwork);
        parameterDifferentiableCostMinimizer.addIterativeAlgorithmListener(this);
        parameterDifferentiableCostMinimizer.setMaxIterations(50);
        this.neuralNet = parameterDifferentiableCostMinimizer.learn(this.dataCollection);
    }

    private static <T> ArrayList<T> toArrayList(T[] tArr) {
        ArrayList<T> arrayList = new ArrayList<>();
        for (T t : tArr) {
            arrayList.add(t);
        }
        return arrayList;
    }

    private Matrix fromCSV(BufferedReader bufferedReader, int i) throws IOException {
        double[][] dArr = (double[][]) null;
        Matrix matrix = null;
        int i2 = 0;
        while (true) {
            String[] nextNonEmptyLine = CSVUtility.nextNonEmptyLine(bufferedReader);
            if (nextNonEmptyLine == null) {
                return matrix;
            }
            if (dArr == null) {
                matrix = new Matrix(i, nextNonEmptyLine.length);
                dArr = matrix.getArray();
            }
            for (int i3 = 0; i3 < nextNonEmptyLine.length; i3++) {
                dArr[i2][i3] = Double.parseDouble(nextNonEmptyLine[i3]);
            }
            i2++;
        }
    }

    public static void main(String[] strArr) throws IOException {
        new HandWritingNeuralNetSANDIA();
    }

    public void algorithmStarted(IterativeAlgorithm iterativeAlgorithm) {
        System.out.println("Learning neural network");
    }

    public void algorithmEnded(IterativeAlgorithm iterativeAlgorithm) {
        System.out.println("Done Learning!");
    }

    public void stepStarted(IterativeAlgorithm iterativeAlgorithm) {
        System.out.println("... starting step: " + iterativeAlgorithm.getIteration());
    }

    public void stepEnded(IterativeAlgorithm iterativeAlgorithm) {
        System.out.println("... ending step: " + iterativeAlgorithm.getIteration());
    }
}
