package org.openimaj.tools.faces.recognition;

import java.io.File;
import java.io.IOException;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.kohsuke.args4j.ProxyOptionHandler;
import org.openimaj.data.dataset.GroupedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.experiment.ExperimentRunner;
import org.openimaj.experiment.validation.cross.StratifiedGroupedKFold;
import org.openimaj.image.FImage;
import org.openimaj.image.processing.face.detection.DetectedFace;
import org.openimaj.image.processing.face.recognition.FaceRecogniser;
import org.openimaj.image.processing.face.recognition.FaceRecognitionEngine;
import org.openimaj.image.processing.face.recognition.benchmarking.CrossValidationBenchmark;
import org.openimaj.image.processing.face.recognition.benchmarking.FaceRecogniserProvider;
import org.openimaj.image.processing.face.recognition.benchmarking.dataset.TextFileDataset;
import org.openimaj.tools.faces.recognition.options.RecognitionEngineProvider;
import org.openimaj.tools.faces.recognition.options.RecognitionStrategy;

/**
 * Command-line tool that cross-validates a face recognition strategy on a dataset described by a
 * text file, printing the result returned by {@link ExperimentRunner} to standard output. When
 * {@code --save-recogniser} is given, a recognition engine trained on the whole dataset is saved to
 * that file afterwards.
 *
 * @param <FACE>
 *            the type of detected face used by the selected recognition strategy
 */
public class FaceRecognitionCrossValidatorTool<FACE extends DetectedFace> {
    /** Smallest fold count for which k-fold cross-validation leaves a non-empty training split. */
    private static final int MIN_FOLDS = 2;

    /**
     * Engine provider for the selected {@link #strategy}. It is not assigned anywhere in this class;
     * presumably it is injected by {@link ProxyOptionHandler} when {@code --strategy} is handled.
     * NOTE(review): confirm it is also populated for the default strategy when the option is omitted.
     */
    RecognitionEngineProvider<FACE> strategyOp;

    /** Dataset description file; each line is {@code IDENTIFIER,img}. */
    @Option(name = "--dataset", usage = "File formatted as each line being: IDENTIFIER,img", required = true)
    File datasetFile;

    /** Optional destination for an engine trained on the full dataset; {@code null} to skip saving. */
    @Option(name = "--save-recogniser", usage = "After cross-validation, create a recogniser using all folds and save it to the given file", required = false)
    File savedRecogniser;

    /** Recognition strategy to evaluate. */
    @Option(name = "--strategy", usage = "Recognition strategy", required = false, handler = ProxyOptionHandler.class)
    RecognitionStrategy strategy = RecognitionStrategy.EigenFaces_KNN;

    /** Number of cross-validation folds (k); must be at least {@value #MIN_FOLDS}. */
    @Option(name = "--num-folds", usage = "number of cross-validation folds", required = false)
    int numFolds = 10;

    /**
     * Runs a stratified grouped k-fold cross-validation of the selected strategy over the dataset,
     * using {@link #numFolds} folds, and prints the experiment result to standard output.
     *
     * @throws IOException
     *             if the dataset cannot be loaded
     */
    // NOTE(review): raw types retained from the decompiled original; restore the generic arguments
    // of CrossValidationBenchmark / StratifiedGroupedKFold once their signatures are confirmed.
    @SuppressWarnings({ "rawtypes", "unchecked" })
    protected void performBenchmark() throws IOException {
        // This engine supplies the face detector used by the benchmark and the provider's description.
        final FaceRecognitionEngine<FACE, String> engine = engineProvider().createRecognitionEngine();

        final FaceRecogniserProvider<FACE, String> provider = new FaceRecogniserProvider<FACE, String>() {
            @Override
            public FaceRecogniser<FACE, String> create(GroupedDataset<String, ? extends ListDataset<FACE>, FACE> groupedDataset) {
                // A new engine is requested on every call rather than reusing `engine`, so one fold's
                // training does not carry into the next (assumes the provider returns a fresh engine).
                final FaceRecogniser<FACE, String> recogniser = engineProvider().createRecognitionEngine().getRecogniser();
                recogniser.train(groupedDataset);
                return recogniser;
            }

            @Override
            public String toString() {
                return engine.getRecogniser().toString();
            }
        };

        // Use the --num-folds value as k instead of a fixed fold count.
        final CrossValidationBenchmark benchmark = new CrossValidationBenchmark(
                new StratifiedGroupedKFold(this.numFolds), getDataset(), engine.getDetector(), provider);
        System.out.println(ExperimentRunner.runExperiment(benchmark));
    }

    /**
     * Trains a recognition engine for the selected strategy on the entire dataset and saves it to
     * {@link #savedRecogniser}.
     *
     * @throws IOException
     *             if loading the dataset or saving the engine fails
     */
    protected void saveRecogniser() throws IOException {
        final FaceRecognitionEngine<FACE, String> engine = engineProvider().createRecognitionEngine();
        engine.train(getDataset());
        engine.save(this.savedRecogniser);
    }

    /**
     * Loads the dataset described by {@link #datasetFile}. A new dataset object is created on every
     * call.
     *
     * @return the dataset grouped by identifier
     * @throws IOException
     *             if the dataset file cannot be read
     */
    private GroupedDataset<String, ListDataset<FImage>, FImage> getDataset() throws IOException {
        return new TextFileDataset(this.datasetFile);
    }

    /**
     * Returns the engine provider for the selected strategy.
     *
     * @return the non-null engine provider
     * @throws IllegalStateException
     *             if no provider has been set, instead of failing later with a bare
     *             NullPointerException
     */
    private RecognitionEngineProvider<FACE> engineProvider() {
        if (this.strategyOp == null) {
            throw new IllegalStateException("No recognition engine provider configured for strategy " + this.strategy);
        }
        return this.strategyOp;
    }

    /**
     * Prints an error message, the command-line usage and the available strategies to standard
     * error.
     *
     * @param parser
     *            the parser whose options are described
     * @param message
     *            the error message to print first
     */
    private static void printUsage(CmdLineParser parser, String message) {
        System.err.println(message);
        System.err.println("java " + FaceRecognitionCrossValidatorTool.class.getName() + " options...");
        parser.printUsage(System.err);
        System.err.println();
        System.err.println("Strategy information:");
        for (final RecognitionStrategy s : RecognitionStrategy.values()) {
            System.err.println(s);
        }
    }

    /**
     * Entry point. It parses the arguments, runs the cross-validation benchmark and, if requested,
     * saves a recogniser trained on the full dataset. Argument errors are reported on standard error
     * together with the usage text.
     *
     * @param args
     *            command-line arguments
     * @param <FACE>
     *            the type of detected face used by the selected strategy
     * @throws IOException
     *             if the dataset cannot be read or the recogniser cannot be saved
     */
    public static <FACE extends DetectedFace> void main(String[] args) throws IOException {
        final FaceRecognitionCrossValidatorTool<FACE> tool = new FaceRecognitionCrossValidatorTool<FACE>();
        final CmdLineParser parser = new CmdLineParser(tool);
        try {
            parser.parseArgument(args);
        } catch (final CmdLineException e) {
            printUsage(parser, e.getMessage());
            return;
        }

        if (tool.numFolds < MIN_FOLDS) {
            printUsage(parser, "--num-folds must be at least " + MIN_FOLDS + " but was " + tool.numFolds);
            return;
        }

        tool.performBenchmark();
        if (tool.savedRecogniser != null) {
            tool.saveRecogniser();
        }
    }
}
