package org.openimaj.ml.neuralnet;

import Jama.Matrix;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.procedure.TIntObjectProcedure;
import gov.sandia.cognition.io.CSVUtility;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.List;
import org.encog.engine.network.activation.ActivationStep;
import org.encog.mathutil.rbf.RBFEnum;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.specific.CSVNeuralDataSet;
import org.encog.ml.svm.SVM;
import org.encog.ml.svm.training.SVMTrain;
import org.encog.neural.cpn.CPN;
import org.encog.neural.cpn.training.TrainInstar;
import org.encog.neural.cpn.training.TrainOutstar;
import org.encog.neural.data.basic.BasicNeuralData;
import org.encog.neural.neat.NEATPopulation;
import org.encog.neural.neat.training.NEATTraining;
import org.encog.neural.networks.training.TrainingSetScore;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.neural.rbf.RBFNetwork;
import org.encog.util.simple.EncogUtility;
import org.openimaj.util.pair.IndependentPair;

/* loaded from: input_file:org/openimaj/ml/neuralnet/HandWritingNeuralNetENCOG.class */
public class HandWritingNeuralNetENCOG {
    public static final String INPUT_LOCATION = "/org/openimaj/ml/handwriting/inputouput.csv";
    private MLRegression network;
    private MLDataSet training;
    private int maxTests = 10;
    private int totalTests = 0;
    private TIntIntHashMap examples = new TIntIntHashMap();
    private TIntObjectHashMap<List<IndependentPair<double[], double[]>>> tests = new TIntObjectHashMap<>();

    public HandWritingNeuralNetENCOG() throws IOException {
        prepareDataCollection();
        learnNeuralNet();
        testNeuralNet();
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [double[], double[][]] */
    private void testNeuralNet() {
        final ?? r0 = new double[this.totalTests];
        final int[] iArr = new int[this.totalTests];
        this.tests.forEachEntry(new TIntObjectProcedure<List<IndependentPair<double[], double[]>>>() { // from class: org.openimaj.ml.neuralnet.HandWritingNeuralNetENCOG.1
            int done = 0;

            public boolean execute(int i, List<IndependentPair<double[], double[]>> list) {
                for (IndependentPair<double[], double[]> independentPair : list) {
                    double[] data = HandWritingNeuralNetENCOG.this.network.compute(new BasicNeuralData((double[]) independentPair.firstObject())).getData();
                    int i2 = 0;
                    double d = 0.0d;
                    for (int i3 = 0; i3 < data.length; i3++) {
                        if (d < data[i3]) {
                            d = data[i3];
                            i2 = i3;
                        }
                    }
                    r0[this.done] = (double[]) independentPair.firstObject();
                    iArr[this.done] = (i2 + 1) % 10;
                    this.done++;
                }
                return true;
            }
        });
        new HandWritingInputDisplay(r0, iArr);
    }

    private void prepareDataCollection() throws IOException {
        File createTempFile = File.createTempFile("data", ".csv");
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(HandWritingNeuralNetENCOG.class.getResourceAsStream(INPUT_LOCATION)));
        PrintWriter printWriter = new PrintWriter(new FileWriter(createTempFile));
        while (true) {
            String readLine = bufferedReader.readLine();
            if (readLine == null) {
                break;
            } else {
                printWriter.println(readLine);
            }
        }
        printWriter.close();
        bufferedReader.close();
        this.training = new CSVNeuralDataSet(createTempFile.getAbsolutePath(), 400, 10, false);
        for (MLDataPair mLDataPair : this.training) {
            double[] idealArray = mLDataPair.getIdealArray();
            double[] inputArray = mLDataPair.getInputArray();
            int i = 0;
            while (idealArray[i] != 1.0d) {
                i++;
            }
            if (this.examples.adjustOrPutValue(i, 1, 1) < this.maxTests) {
                List list = (List) this.tests.get(i);
                if (list == null) {
                    ArrayList arrayList = new ArrayList();
                    list = arrayList;
                    this.tests.put(i, arrayList);
                }
                list.add(IndependentPair.pair(inputArray, idealArray));
                this.totalTests++;
            }
        }
    }

    private void learnNeuralNet() {
        this.network = withResilieant();
    }

    private MLRegression withNEAT() {
        NEATPopulation nEATPopulation = new NEATPopulation(400, 10, 1000);
        TrainingSetScore trainingSetScore = new TrainingSetScore(this.training);
        ActivationStep activationStep = new ActivationStep();
        activationStep.setCenter(0.5d);
        nEATPopulation.setOutputActivationFunction(activationStep);
        NEATTraining nEATTraining = new NEATTraining(trainingSetScore, nEATPopulation);
        EncogUtility.trainToError(nEATTraining, 0.01515d);
        return nEATTraining.getMethod();
    }

    private MLRegression withResilieant() {
        ResilientPropagation resilientPropagation = new ResilientPropagation(EncogUtility.simpleFeedForward(400, 100, 0, 10, false), this.training);
        EncogUtility.trainToError(resilientPropagation, 0.01515d);
        return resilientPropagation.getMethod();
    }

    private MLRegression withSVM() {
        SVMTrain sVMTrain = new SVMTrain(new SVM(400, true), this.training);
        EncogUtility.trainToError(sVMTrain, 0.01515d);
        return sVMTrain.getMethod();
    }

    private MLRegression withRBF() {
        RBFNetwork rBFNetwork = new RBFNetwork(400, 20, 10, RBFEnum.Gaussian);
        EncogUtility.trainToError(rBFNetwork, this.training, 0.01515d);
        return rBFNetwork;
    }

    private MLRegression withCPN() {
        CPN cpn = new CPN(400, 1000, 10, 1);
        EncogUtility.trainToError(new TrainInstar(cpn, this.training, 0.1d, false), 0.01515d);
        EncogUtility.trainToError(new TrainOutstar(cpn, this.training, 0.1d), 0.01515d);
        return cpn;
    }

    private static <T> ArrayList<T> toArrayList(T[] tArr) {
        ArrayList<T> arrayList = new ArrayList<>();
        for (T t : tArr) {
            arrayList.add(t);
        }
        return arrayList;
    }

    private Matrix fromCSV(BufferedReader bufferedReader, int i) throws IOException {
        double[][] dArr = (double[][]) null;
        Matrix matrix = null;
        int i2 = 0;
        while (true) {
            String[] nextNonEmptyLine = CSVUtility.nextNonEmptyLine(bufferedReader);
            if (nextNonEmptyLine == null) {
                return matrix;
            }
            if (dArr == null) {
                matrix = new Matrix(i, nextNonEmptyLine.length);
                dArr = matrix.getArray();
            }
            for (int i3 = 0; i3 < nextNonEmptyLine.length; i3++) {
                dArr[i2][i3] = Double.parseDouble(nextNonEmptyLine[i3]);
            }
            i2++;
        }
    }

    public static void main(String[] strArr) throws IOException {
        new HandWritingNeuralNetENCOG();
    }
}
