package org.openimaj.demos.ml.linear.data;

import gov.sandia.cognition.learning.algorithm.perceptron.kernel.KernelPerceptron;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.kernel.LinearKernel;
import gov.sandia.cognition.math.matrix.VectorFactory;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import no.uib.cipr.matrix.Vector;
import org.openimaj.image.DisplayUtilities;
import org.openimaj.image.MBFImage;
import org.openimaj.image.colour.ColourSpace;
import org.openimaj.image.colour.RGBColour;
import org.openimaj.math.geometry.line.Line2d;
import org.openimaj.math.geometry.point.Point2d;
import org.openimaj.math.geometry.point.Point2dImpl;
import org.openimaj.math.geometry.shape.Circle;
import org.openimaj.ml.linear.data.LinearPerceptronDataGenerator;
import org.openimaj.ml.linear.kernel.LinearVectorKernel;
import org.openimaj.ml.linear.learner.perceptron.MatrixKernelPerceptron;
import org.openimaj.ml.linear.learner.perceptron.MeanCenteredKernelPerceptron;
import org.openimaj.ml.linear.learner.perceptron.MeanCenteredProjectron;
import org.openimaj.ml.linear.learner.perceptron.PerceptronClass;
import org.openimaj.ml.linear.learner.perceptron.SimplePerceptron;
import org.openimaj.util.pair.IndependentPair;
import org.openimaj.util.stream.Stream;

/* loaded from: input_file:org/openimaj/demos/ml/linear/data/DrawLinearData.class */
/**
 * Visual demo for the OpenIMAJ linear perceptron learners.
 *
 * <p>{@link #main(String[])} displays a batch of points from a seeded
 * {@link LinearPerceptronDataGenerator} together with the generator's plane. It then writes a batch
 * of generated points to a text file. Finally it trains a {@link MeanCenteredKernelPerceptron} on a
 * {@link RepeatingDataStream} of generated points and displays the learned separating line.
 * Points are handled through their first two coordinates only.
 */
public class DrawLinearData {
    /** Number of items per drawn/written batch, and the length of one training pass. */
    private static final int TOTAL_DATA_ITEMS = 1000;

    /** Random seed handed to {@link LinearPerceptronDataGenerator}; used only as a seed. */
    private static final int SEED = 1;

    /** Width and height, in pixels, of every image the demo displays. */
    private static final int IMAGE_SIZE = 300;

    /** Thickness, in pixels, of drawn lines and unfilled circles. */
    private static final int LINE_THICKNESS = 3;

    /** Radius, in pixels, of the circles that mark individual data points. */
    private static final float POINT_RADIUS = 5.0f;

    /**
     * Distance travelled along the generator's plane direction on each side of its origin when
     * drawing that plane; chosen large so the segment presumably crosses the whole image.
     */
    private static final double PLANE_EXTENT = 10000.0d;

    /** Data file written by {@link #main(String[])} when no path argument is supplied. */
    private static final String DEFAULT_OUTPUT_PATH = "/Users/ss/Experiments/perceptron/test.data";

    /**
     * Runs the demo: draw the data, write it to disk, then train and draw a perceptron.
     *
     * @param args optional; {@code args[0]} overrides the path of the data file that is written
     *     (defaults to {@value #DEFAULT_OUTPUT_PATH})
     * @throws IOException if the data file cannot be written
     */
    public static void main(String[] args) throws IOException {
        File output = new File(args.length > 0 ? args[0] : DEFAULT_OUTPUT_PATH);
        drawData(dataGen());
        writeData(output);
        // The cast selects the MatrixKernelPerceptron overload of learnPoints.
        learnPoints(
                (MatrixKernelPerceptron) new MeanCenteredKernelPerceptron(new LinearVectorKernel()),
                (Iterable<IndependentPair<double[], PerceptronClass>>) new RepeatingDataStream(dataGen(), TOTAL_DATA_ITEMS));
    }

    /**
     * Writes {@link #TOTAL_DATA_ITEMS} freshly generated items to {@code file}, creating missing
     * parent directories. Each item occupies two lines: the feature vector as formatted by
     * {@link Arrays#toString(double[])}, then {@code 1} if its class is
     * {@link PerceptronClass#TRUE} and {@code 0} otherwise.
     *
     * @param file destination file; overwritten if it exists
     * @throws IOException if the parent directory cannot be created or the file cannot be written
     */
    private static void writeData(File file) throws IOException {
        LinearPerceptronDataGenerator generator = dataGen();
        // Resolve against the working directory first: a bare file name has no parent otherwise.
        File parent = file.getAbsoluteFile().getParentFile();
        if (parent != null && !parent.isDirectory() && !parent.mkdirs()) {
            throw new IOException("Could not create directory " + parent);
        }
        try (PrintWriter out = new PrintWriter(file)) {
            for (int n = 0; n < TOTAL_DATA_ITEMS; n++) {
                IndependentPair<double[], PerceptronClass> item = generator.generate();
                out.println(Arrays.toString(item.firstObject()));
                out.println(item.secondObject() == PerceptronClass.TRUE ? 1 : 0);
            }
            // PrintWriter never throws on write failure; surface it instead of leaving a truncated file.
            if (out.checkError()) {
                throw new IOException("Failed writing " + file);
            }
        }
    }

    /**
     * Displays {@link #TOTAL_DATA_ITEMS} items from {@code generator} together with the line
     * through the generator's origin along the first vector of its plane.
     */
    private static void drawData(LinearPerceptronDataGenerator generator) {
        LimitedDataStream points = new LimitedDataStream(generator, TOTAL_DATA_ITEMS);
        Vector origin = generator.getOrigin();
        Vector direction = generator.getPlane()[0];
        drawPoints(points, new Line2d(start(origin, direction), end(origin, direction)));
    }

    /** Creates the data generator used throughout the demo, seeded with {@link #SEED}. */
    private static LinearPerceptronDataGenerator dataGen() {
        return new LinearPerceptronDataGenerator(300.0d, 2, 0.3d, SEED);
    }

    /**
     * Trains a Cognitive Foundry {@link KernelPerceptron} with a {@link LinearKernel} on
     * {@link #createData()}. Not called from {@link #main(String[])}.
     */
    private static void learnCogFoundry() {
        new KernelPerceptron(new LinearKernel()).learn(createData());
    }

    /**
     * Converts {@link #TOTAL_DATA_ITEMS} generated items into Cognitive Foundry input/output pairs
     * whose output is {@code true} exactly when the item's class is {@link PerceptronClass#TRUE}.
     */
    private static Collection<? extends InputOutputPair<? extends gov.sandia.cognition.math.matrix.Vector, Boolean>> createData() {
        List<InputOutputPair<? extends gov.sandia.cognition.math.matrix.Vector, Boolean>> pairs =
                new ArrayList<>(TOTAL_DATA_ITEMS);
        LinearPerceptronDataGenerator generator = dataGen();
        for (int n = 0; n < TOTAL_DATA_ITEMS; n++) {
            IndependentPair<double[], PerceptronClass> item = generator.generate();
            gov.sandia.cognition.math.matrix.Vector input =
                    VectorFactory.getDenseDefault().copyArray(item.firstObject());
            pairs.add(DefaultInputOutputPair.create(input, item.secondObject() == PerceptronClass.TRUE));
        }
        System.out.println("Data created");
        return pairs;
    }

    /**
     * Displays the decision line learned by {@code perceptron} across the full image width.
     * For mean-centred learners the line is computed in centred coordinates and shifted back by
     * the learner's mean before drawing.
     */
    private static void drawMkpLine(MatrixKernelPerceptron perceptron) {
        MBFImage image = new MBFImage(IMAGE_SIZE, IMAGE_SIZE, ColourSpace.RGB);
        List<double[]> supports = perceptron.getSupports();
        List<Double> weights = perceptron.getWeights();
        double bias = perceptron.getBias();
        System.out.println("Bias: " + bias);
        double[] mean = centringMean(perceptron);
        // NaN marks the coordinate getPlanePoint is asked to solve for.
        double[] from = LinearVectorKernel.getPlanePoint(supports, weights, bias, new double[] {-mean[0], Double.NaN});
        double[] to = LinearVectorKernel.getPlanePoint(supports, weights, bias,
                new double[] {image.getWidth() - mean[0], Double.NaN});
        addMean(from, mean);
        addMean(to, mean);
        drawLine(image, from, to);
    }

    /**
     * Returns the mean reported by a mean-centred learner, or a zero 2-vector for any other
     * learner, so callers can map between centred and image coordinates uniformly.
     */
    private static double[] centringMean(MatrixKernelPerceptron perceptron) {
        if (perceptron instanceof MeanCenteredKernelPerceptron) {
            return ((MeanCenteredKernelPerceptron) perceptron).getMean();
        }
        if (perceptron instanceof MeanCenteredProjectron) {
            return ((MeanCenteredProjectron) perceptron).getMean();
        }
        return new double[2];
    }

    /** Adds the first two components of {@code mean} to {@code point}, in place. */
    private static void addMean(double[] point, double[] mean) {
        point[0] += mean[0];
        point[1] += mean[1];
    }

    /** Draws a green segment between two points on {@code image} and displays it as "line". */
    private static void drawLine(MBFImage image, double[] from, double[] to) {
        image.drawLine(new Line2d(toPoint(from), toPoint(to)), LINE_THICKNESS, RGBColour.GREEN);
        DisplayUtilities.displayName(image, "line");
    }

    /** Builds an image point from the first two coordinates of {@code xy}. */
    private static Point2dImpl toPoint(double[] xy) {
        return new Point2dImpl((float) xy[0], (float) xy[1]);
    }

    /**
     * Trains {@code perceptron} on {@code data} in passes of {@link #TOTAL_DATA_ITEMS} items,
     * stopping after the first pass in which it mispredicts nothing, then draws its line.
     *
     * <p>NOTE(review): the loop only ends on an error-free pass or when {@code data} is
     * exhausted; with a repeating stream and non-separable data it may never terminate.
     */
    private static void learnPoints(SimplePerceptron perceptron, Iterable<IndependentPair<double[], PerceptronClass>> data) {
        int mistakesThisPass = 0;
        int seenThisPass = 0;
        for (IndependentPair<double[], PerceptronClass> item : data) {
            seenThisPass++;
            double[] x = item.firstObject();
            Integer label = item.getSecondObject() == PerceptronClass.TRUE ? 1 : 0;
            // Predict before updating so the mistake count reflects the model that saw the item.
            boolean mistaken = !label.equals(perceptron.predict(x));
            perceptron.process(x, label);
            if (mistaken) {
                mistakesThisPass++;
            }
            if (seenThisPass == TOTAL_DATA_ITEMS) {
                if (mistakesThisPass == 0) {
                    break;
                }
                seenThisPass = 0;
                mistakesThisPass = 0;
            }
        }
        drawSpLine(perceptron);
    }

    /** Displays the hyperplane of {@code perceptron} across the full image width. */
    private static void drawSpLine(SimplePerceptron perceptron) {
        MBFImage image = new MBFImage(IMAGE_SIZE, IMAGE_SIZE, ColourSpace.RGB);
        double[] from = perceptron.computeHyperplanePoint(new double[] {0.0d, Double.NaN});
        double[] to = perceptron.computeHyperplanePoint(new double[] {image.getWidth(), Double.NaN});
        drawLine(image, from, to);
    }

    /**
     * Trains {@code perceptron} on {@code data} in passes of {@link #TOTAL_DATA_ITEMS} items,
     * printing the bias and direction after every update. It stops after the first pass during
     * which the learner's error counter does not change. It then draws the learned line and
     * prints the number of supports.
     *
     * <p>NOTE(review): as with the {@link SimplePerceptron} overload, this may not terminate on a
     * repeating stream of non-separable data.
     */
    private static void learnPoints(MatrixKernelPerceptron perceptron, Iterable<IndependentPair<double[], PerceptronClass>> data) {
        int seenThisPass = 0;
        int mistakesThisPass = 0;
        for (IndependentPair<double[], PerceptronClass> item : data) {
            seenThisPass++;
            int errorsBefore = perceptron.getErrors();
            perceptron.process(item.firstObject(), item.getSecondObject());
            System.out.println("b: " + perceptron.getBias() + " w: "
                    + Arrays.toString(LinearVectorKernel.getDirection(perceptron.getSupports(), perceptron.getWeights())));
            if (errorsBefore != perceptron.getErrors()) {
                mistakesThisPass++;
            }
            if (seenThisPass == TOTAL_DATA_ITEMS) {
                if (mistakesThisPass == 0) {
                    break;
                }
                seenThisPass = 0;
                mistakesThisPass = 0;
            }
        }
        drawMkpLine(perceptron);
        System.out.println(perceptron.getSupports().size());
    }

    /**
     * Displays {@code line2d} in blue plus every point of {@code stream}. {@link PerceptronClass#TRUE}
     * points are drawn as filled green circles and {@link PerceptronClass#FALSE} points as red
     * outlines.
     *
     * @throws RuntimeException if a point is labelled {@link PerceptronClass#NONE}
     */
    private static void drawPoints(Stream<IndependentPair<double[], PerceptronClass>> stream, Line2d line2d) {
        MBFImage image = new MBFImage(IMAGE_SIZE, IMAGE_SIZE, ColourSpace.RGB);
        image.drawLine(line2d, LINE_THICKNESS, RGBColour.BLUE);
        Iterator<IndependentPair<double[], PerceptronClass>> it = stream.iterator();
        while (it.hasNext()) {
            IndependentPair<double[], PerceptronClass> item = it.next();
            Point2dImpl centre = toPoint(item.firstObject());
            switch (item.getSecondObject()) {
                case TRUE:
                    image.drawShapeFilled(new Circle(centre, POINT_RADIUS), RGBColour.GREEN);
                    break;
                case FALSE:
                    image.drawShape(new Circle(centre, POINT_RADIUS), LINE_THICKNESS, RGBColour.RED);
                    break;
                case NONE:
                    throw new RuntimeException("NOPE");
                default:
                    // Any other class is left undrawn.
                    break;
            }
        }
        DisplayUtilities.displayName(image, "random");
    }

    /** Returns the far end of the drawn plane segment: {@code origin + PLANE_EXTENT * direction}. */
    private static Point2d end(Vector origin, Vector direction) {
        return pointAlong(origin, direction, PLANE_EXTENT);
    }

    /** Returns the near end of the drawn plane segment: {@code origin - PLANE_EXTENT * direction}. */
    private static Point2d start(Vector origin, Vector direction) {
        return pointAlong(origin, direction, -PLANE_EXTENT);
    }

    /**
     * Returns the image point at {@code origin + t * direction}, using the first two components.
     * {@code origin} is copied first, so it is not modified.
     */
    private static Point2d pointAlong(Vector origin, Vector direction, double t) {
        Vector p = origin.copy().add(t, direction);
        return new Point2dImpl((float) p.get(0), (float) p.get(1));
    }
}
