package org.openimaj.experiment.validation.cross;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

import org.openimaj.data.RandomData;
import org.openimaj.data.dataset.ListBackedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.experiment.dataset.util.DatasetAdaptors;
import org.openimaj.experiment.validation.DefaultValidationData;
import org.openimaj.experiment.validation.ValidationData;
import org.openimaj.util.list.AcceptingListView;
import org.openimaj.util.list.SkippingListView;

/**
 * K-Fold cross-validation over a {@link ListDataset}.
 * <p>
 * The items of the dataset are shuffled into {@code k} disjoint folds whose
 * sizes differ by at most one. Iteration {@code j} uses fold {@code j} as the
 * validation set and all remaining items as the training set. Every item is
 * therefore used for validation exactly once across the {@code k} iterations.
 *
 * @param <INSTANCE>
 *            type of the items in the dataset
 */
public class KFold<INSTANCE> implements CrossValidator<ListDataset<INSTANCE>> {
    /** Number of folds. Validated when an iterable is created, not here. */
    private final int k;

    /**
     * The folds of a single dataset.
     * <p>
     * The random partition is computed once, in the constructor. Every
     * iterator obtained from the same instance therefore visits the same
     * folds.
     */
    public class KFoldIterable implements CrossValidationIterable<ListDataset<INSTANCE>> {
        /** List view over the dataset; the folds index into this list. */
        private final List<INSTANCE> listView;
        /** subsetIndices[f] holds the indices (into listView) of the items in fold f. */
        private final int[][] subsetIndices;

        /**
         * Randomly partitions {@code dataset} into {@code numFolds} folds.
         * Fold sizes differ by at most one.
         *
         * @param dataset
         *            the dataset to split
         * @param numFolds
         *            the number of folds; must satisfy
         *            {@code 1 <= numFolds <= dataset.size()}
         * @throws IllegalArgumentException
         *             if {@code numFolds < 1} or
         *             {@code numFolds > dataset.size()}
         */
        public KFoldIterable(ListDataset<INSTANCE> dataset, int numFolds) {
            if (numFolds <= 0) {
                throw new IllegalArgumentException("The number of folds must be at least one");
            }
            final int n = dataset.size();
            if (numFolds > n) {
                // numFolds == n (leave-one-out) is permitted, hence "or equal to".
                throw new IllegalArgumentException(
                        "The number of folds must be less than or equal to the number of items in the dataset");
            }

            this.listView = DatasetAdaptors.asList(dataset);

            // n unique random ints starting at 0. The code below assumes this is a
            // permutation of 0..n-1 (i.e. max is exclusive) so that the folds cover
            // every item exactly once -- NOTE(review): confirm against RandomData.
            final int[] shuffled = RandomData.getUniqueRandomInts(n, 0, n);

            // The first (n % numFolds) folds get one extra item. This keeps fold
            // sizes within one of each other instead of dumping the whole
            // remainder into the last fold.
            final int baseSize = n / numFolds;
            final int remainder = n % numFolds;

            this.subsetIndices = new int[numFolds][];
            int start = 0;
            for (int fold = 0; fold < numFolds; fold++) {
                final int end = start + baseSize + (fold < remainder ? 1 : 0);
                this.subsetIndices[fold] = Arrays.copyOfRange(shuffled, start, end);
                start = end;
            }
        }

        /**
         * @return the number of folds, which is also the number of
         *         train/validate iterations
         */
        @Override
        public int numberIterations() {
            return this.subsetIndices.length;
        }

        /**
         * Returns an iterator that yields one train/validation split per fold,
         * in fold order.
         * <p>
         * Each split is built from list views over the original dataset. The
         * views are constructed from the dataset view and the fold's indices;
         * no items are copied here.
         *
         * @return a new iterator over the splits; {@code remove()} is
         *         unsupported
         */
        @Override
        public Iterator<ValidationData<ListDataset<INSTANCE>>> iterator() {
            return new Iterator<ValidationData<ListDataset<INSTANCE>>>() {
                /** Index of the fold that the next call to next() uses for validation. */
                private int validationSubset = 0;

                @Override
                public boolean hasNext() {
                    return this.validationSubset < KFoldIterable.this.subsetIndices.length;
                }

                @Override
                public ValidationData<ListDataset<INSTANCE>> next() {
                    // The Iterator contract requires NoSuchElementException
                    // rather than an array index error.
                    if (!hasNext()) {
                        throw new NoSuchElementException();
                    }
                    final int[] validationIndices = KFoldIterable.this.subsetIndices[this.validationSubset++];

                    // Training data: every item except those in the validation fold.
                    final ListDataset<INSTANCE> training = new ListBackedDataset<INSTANCE>(
                            new SkippingListView<INSTANCE>(KFoldIterable.this.listView, validationIndices));
                    // Validation data: only the items in the validation fold.
                    final ListDataset<INSTANCE> validation = new ListBackedDataset<INSTANCE>(
                            new AcceptingListView<INSTANCE>(KFoldIterable.this.listView, validationIndices));

                    return new DefaultValidationData<ListDataset<INSTANCE>>(training, validation);
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException();
                }
            };
        }
    }

    /**
     * Constructs a K-Fold cross-validator.
     *
     * @param k
     *            the number of folds. The value is not checked here; it is
     *            validated against the dataset when
     *            {@link #createIterable(ListDataset)} is called.
     */
    public KFold(int k) {
        this.k = k;
    }

    /**
     * Creates the folds for the given dataset.
     *
     * @param listDataset
     *            the dataset to split
     * @return an iterable over the {@code k} train/validation splits
     * @throws IllegalArgumentException
     *             if {@code k < 1} or {@code k > listDataset.size()}
     */
    @Override
    public CrossValidationIterable<ListDataset<INSTANCE>> createIterable(ListDataset<INSTANCE> listDataset) {
        return new KFoldIterable(listDataset, this.k);
    }

    @Override
    public String toString() {
        return this.k + "-Fold Cross-Validation";
    }
}
