package gov.sandia.cognition.learning.algorithm.ensemble;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.factory.Factory;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.statistics.DataHistogram;
import gov.sandia.cognition.statistics.distribution.MapBasedDataHistogram;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Random;

/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/ensemble/CategoryBalancedIVotingLearner.class */
/**
 * An IVoting categorizer learner whose bags are balanced across categories.
 * <p>
 * Rather than sampling from the pooled "correct" and "incorrect" example
 * index lists, {@link #createBag} first splits both lists by the output label
 * of each example and then draws the same number of samples (with
 * replacement) from every category.
 *
 * @param <InputType> the type of input the learned categorizers take
 * @param <CategoryType> the type of the category labels
 */
public class CategoryBalancedIVotingLearner<InputType, CategoryType> extends IVotingCategorizerLearner<InputType, CategoryType> {

    /**
     * Creates a learner with a {@code null} batch learner, 100 and 0.1 as the
     * int/double superclass arguments, and a fresh {@link Random}.
     * NOTE(review): a batch learner presumably has to be set before learning.
     */
    public CategoryBalancedIVotingLearner() {
        this(null, 100, 0.1d, new Random());
    }

    /**
     * Creates a learner that passes 0.5 and {@code true} as the additional
     * double/boolean superclass arguments and uses a
     * {@code MapBasedDataHistogram.DefaultFactory} constructed with 2
     * (presumably an initial capacity — confirm) as the counter factory.
     *
     * @param learner the base batch learner used to train ensemble members
     * @param maxIterations forwarded to the superclass (presumably the maximum
     *     number of iterations — confirm in IVotingCategorizerLearner)
     * @param percentToSample forwarded to the superclass (presumably the
     *     fraction of the data to sample per bag — confirm)
     * @param random the random number generator used for sampling
     */
    public CategoryBalancedIVotingLearner(
            BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner,
            int maxIterations,
            double percentToSample,
            Random random) {
        this(learner, maxIterations, percentToSample, 0.5d, true,
            new MapBasedDataHistogram.DefaultFactory(2), random);
    }

    /**
     * Creates a learner, forwarding every argument unchanged to the
     * {@link IVotingCategorizerLearner} constructor.
     *
     * @param learner the base batch learner used to train ensemble members
     * @param maxIterations presumably the maximum number of iterations — confirm
     * @param percentToSample presumably the fraction of data sampled per bag — confirm
     * @param proportionIncorrectInSample presumably the share of each bag drawn
     *     from incorrectly-categorized examples — confirm
     * @param voteOutOfBagOnly presumably whether only out-of-bag members vote
     *     when scoring examples — confirm
     * @param counterFactory factory for the per-example vote histograms
     * @param random the random number generator used for sampling
     */
    public CategoryBalancedIVotingLearner(
            BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner,
            int maxIterations,
            double percentToSample,
            double proportionIncorrectInSample,
            boolean voteOutOfBagOnly,
            Factory<? extends DataHistogram<CategoryType>> counterFactory,
            Random random) {
        super(learner, maxIterations, percentToSample, proportionIncorrectInSample,
            voteOutOfBagOnly, counterFactory, random);
    }

    /**
     * Fills the current bag with a category-balanced sample.
     * <p>
     * The two index lists are grouped by the output label of the corresponding
     * entry in {@code dataList}. Each category then contributes
     * {@code max(1, numCorrectToSample / categoryCount)} draws from its
     * "correct" indices and {@code max(1, numIncorrectToSample / categoryCount)}
     * draws from its "incorrect" indices. Draws are made with replacement
     * into {@code currentBag} / {@code dataInBag}.
     * <p>
     * If a category has no incorrect indices, its incorrect draws come from its
     * correct ones, and vice versa. A category with no indices in either list
     * is skipped.
     *
     * @param correctIndices indices into {@code dataList} of the "correct"
     *     pool (presumably examples the current ensemble gets right)
     * @param incorrectIndices indices into {@code dataList} of the "incorrect"
     *     pool (presumably examples the current ensemble gets wrong)
     */
    @Override // gov.sandia.cognition.learning.algorithm.ensemble.IVotingCategorizerLearner
    protected void createBag(
            final ArrayList<Integer> correctIndices,
            final ArrayList<Integer> incorrectIndices) {
        // One bucket per known category; insertion order follows getCategories().
        final LinkedHashMap<CategoryType, ArrayList<Integer>> correctByCategory =
            new LinkedHashMap<CategoryType, ArrayList<Integer>>();
        final LinkedHashMap<CategoryType, ArrayList<Integer>> incorrectByCategory =
            new LinkedHashMap<CategoryType, ArrayList<Integer>>();
        for (CategoryType category : this.ensemble.getCategories()) {
            correctByCategory.put(category, new ArrayList<Integer>());
            incorrectByCategory.put(category, new ArrayList<Integer>());
        }

        // NOTE(review): assumes every output label in dataList is one of
        // ensemble.getCategories(); an unknown label makes get() return null.
        for (Integer index : correctIndices) {
            correctByCategory.get(this.dataList.get(index.intValue()).getOutput()).add(index);
        }
        for (Integer index : incorrectIndices) {
            incorrectByCategory.get(this.dataList.get(index.intValue()).getOutput()).add(index);
        }

        // Split each sampling budget evenly, but always draw at least one.
        final int categoryCount = this.ensemble.getCategories().size();
        final int correctPerCategory = Math.max(1, this.numCorrectToSample / categoryCount);
        final int incorrectPerCategory = Math.max(1, this.numIncorrectToSample / categoryCount);

        for (CategoryType category : this.ensemble.getCategories()) {
            ArrayList<Integer> correctPool = correctByCategory.get(category);
            ArrayList<Integer> incorrectPool = incorrectByCategory.get(category);

            if (correctPool.isEmpty() && incorrectPool.isEmpty()) {
                // No examples of this category at all: nothing to draw from.
                continue;
            }
            if (incorrectPool.isEmpty()) {
                incorrectPool = correctPool;
            } else if (correctPool.isEmpty()) {
                // Test this category's own list (not the pooled correctIndices)
                // so a category whose examples are all incorrect never samples
                // a positive count from an empty list.
                correctPool = incorrectPool;
            }

            sampleIndicesWithReplacementInto(correctPool, this.dataList,
                correctPerCategory, this.random, this.currentBag, this.dataInBag);
            sampleIndicesWithReplacementInto(incorrectPool, this.dataList,
                incorrectPerCategory, this.random, this.currentBag, this.dataInBag);
        }
    }
}
