package org.openimaj.experiment.validation.cross;

import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.openimaj.data.RandomData;
import org.openimaj.experiment.dataset.GroupedDataset;
import org.openimaj.experiment.dataset.ListBackedDataset;
import org.openimaj.experiment.dataset.ListDataset;
import org.openimaj.experiment.dataset.MapBackedDataset;
import org.openimaj.experiment.dataset.util.DatasetAdaptors;
import org.openimaj.experiment.validation.DefaultValidationData;
import org.openimaj.experiment.validation.ValidationData;
import org.openimaj.util.list.AcceptingListView;
import org.openimaj.util.list.SkippingListView;
import org.openimaj.util.pair.IntObjectPair;

/**
 * K-fold cross-validation for {@link GroupedDataset}s.
 * <p>
 * All instances of the dataset (across every group) are randomly permuted and
 * split into {@code k} folds. The first {@code k - 1} folds contain
 * {@code size / k} instances each, and the last fold takes the remainder. Each
 * iteration holds out one fold for validation and trains on the others.
 * <p>
 * The group structure is kept: every group key appears in both the training
 * and the validation dataset of every split, possibly with no instances.
 *
 * @param <KEY>
 *            type of the group keys
 * @param <INSTANCE>
 *            type of the instances
 */
public class GroupedKFold<KEY, INSTANCE> implements CrossValidator<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> {
    /** Number of folds requested for every dataset this validator splits. */
    private final int k;

    /**
     * The {@code k} train/validation splits of a single dataset.
     * <p>
     * The random fold assignment is computed once, in the constructor. Every
     * iterator obtained from the same instance therefore yields the same
     * splits.
     */
    public class GroupedKFoldIterable implements CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> {
        private final GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset;

        /**
         * For each group key, {@code subsetIndices.get(key)[fold]} holds the
         * within-group indices of that group's instances assigned to
         * {@code fold}. Every cell is non-null; it is empty when the group has
         * no instance in that fold.
         */
        private final Map<KEY, int[][]> subsetIndices = new HashMap<>();

        private final int numFolds;

        /**
         * Randomly assign the instances of the dataset to folds.
         *
         * @param groupedDataset
         *            the dataset to split. The sizes of its groups must add
         *            up to {@code groupedDataset.size()}.
         * @param numberOfFolds
         *            the number of folds, in {@code [1, groupedDataset.size()]}
         * @throws IllegalArgumentException
         *             if the number of folds is out of range, or if the
         *             group sizes do not add up to the dataset size
         */
        public GroupedKFoldIterable(final GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> groupedDataset, final int numberOfFolds) {
            if (numberOfFolds > groupedDataset.size()) {
                throw new IllegalArgumentException("The number of folds must be less than the number of items in the dataset");
            }
            if (numberOfFolds <= 0) {
                throw new IllegalArgumentException("The number of folds must be at least one");
            }
            this.dataset = groupedDataset;
            this.numFolds = numberOfFolds;

            final int total = groupedDataset.size();
            final List<KEY> keys = new ArrayList<>(groupedDataset.getGroups());

            // Fetch each group's size once. The original implementation
            // re-queried every group for every instance.
            final int[] groupSizes = new int[keys.size()];
            int instanceCount = 0;
            for (int g = 0; g < groupSizes.length; g++) {
                groupSizes[g] = groupedDataset.getInstances(keys.get(g)).size();
                instanceCount += groupSizes[g];
            }
            if (instanceCount != total) {
                throw new IllegalArgumentException("The group sizes (" + instanceCount
                        + ") do not add up to the size of the dataset (" + total + ")");
            }

            // Flat (dataset-wide) indices enumerate the groups in the order of
            // 'keys', and each group's instances in list order. Precompute
            // flat index -> (group position, within-group index).
            final int[] groupOfFlat = new int[total];
            final int[] localOfFlat = new int[total];
            for (int g = 0, flat = 0; g < groupSizes.length; g++) {
                for (int local = 0; local < groupSizes[g]; local++, flat++) {
                    groupOfFlat[flat] = g;
                    localOfFlat[flat] = local;
                }
            }

            // Start every (group, fold) cell as an empty array. A group may
            // receive no instances in some fold, and a null cell would later
            // be handed to the list views built in the iterator.
            final int[] empty = new int[0];
            for (final KEY key : keys) {
                final int[][] perFold = new int[numberOfFolds][];
                Arrays.fill(perFold, empty);
                subsetIndices.put(key, perFold);
            }

            final int[] shuffled = RandomData.getUniqueRandomInts(total, 0, total);
            final int splitSize = total / numberOfFolds;
            for (int fold = 0; fold < numberOfFolds; fold++) {
                final int start = splitSize * fold;
                // the last fold absorbs the remainder of total / numberOfFolds
                final int end = (fold == numberOfFolds - 1) ? total : splitSize * (fold + 1);

                // Within-group indices of this fold, collected per group position.
                // Entries are appended in shuffled order.
                final TIntArrayList[] members = new TIntArrayList[keys.size()];
                for (int p = start; p < end; p++) {
                    final int flat = shuffled[p];
                    final int g = groupOfFlat[flat];
                    if (members[g] == null) {
                        members[g] = new TIntArrayList();
                    }
                    members[g].add(localOfFlat[flat]);
                }

                for (int g = 0; g < members.length; g++) {
                    if (members[g] != null) {
                        subsetIndices.get(keys.get(g))[fold] = members[g].toArray();
                    }
                }
            }
        }

        @Override
        public int numberIterations() {
            return numFolds;
        }

        /**
         * Iterate over the splits.
         * <p>
         * Split {@code i} uses fold {@code i} of every group for validation and
         * the remaining instances of that group for training.
         */
        @Override
        public Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>> iterator() {
            return new Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>>() {
                /** Index of the fold used for validation by the next call to next(). */
                private int validationSubset = 0;

                @Override
                public boolean hasNext() {
                    return validationSubset < numFolds;
                }

                /**
                 * @throws NoSuchElementException
                 *             if all folds have already been returned
                 */
                @Override
                public ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> next() {
                    if (!hasNext()) {
                        throw new NoSuchElementException();
                    }

                    final Map<KEY, ListDataset<INSTANCE>> training = new HashMap<>();
                    final Map<KEY, ListDataset<INSTANCE>> validation = new HashMap<>();

                    for (final Map.Entry<KEY, int[][]> entry : subsetIndices.entrySet()) {
                        final KEY group = entry.getKey();
                        final int[] heldOut = entry.getValue()[validationSubset];
                        final List<INSTANCE> groupData = DatasetAdaptors.asList(dataset.getInstances(group));

                        // training = the group's instances minus the held-out fold;
                        // validation = exactly the held-out fold
                        training.put(group, new ListBackedDataset<INSTANCE>(new SkippingListView<INSTANCE>(groupData, heldOut)));
                        validation.put(group, new ListBackedDataset<INSTANCE>(new AcceptingListView<INSTANCE>(groupData, heldOut)));
                    }

                    validationSubset++;

                    return new DefaultValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>(
                            new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>(training),
                            new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>(validation));
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException();
                }
            };
        }
    }

    /**
     * Construct a grouped k-fold cross-validator.
     *
     * @param k
     *            the number of folds. It is validated against the dataset in
     *            {@link #createIterable(GroupedDataset)}.
     */
    public GroupedKFold(int k) {
        this.k = k;
    }

    /**
     * Randomly split the given dataset into {@code k} folds.
     *
     * @throws IllegalArgumentException
     *             if {@code k} is not in {@code [1, dataset.size()]}
     */
    @Override
    public CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> createIterable(GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> groupedDataset) {
        return new GroupedKFoldIterable(groupedDataset, k);
    }

    @Override
    public String toString() {
        return k + "-Fold Cross-Validation for grouped datasets";
    }
}
