/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.ensemble;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.factory.Factory;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.ensemble.IVotingCategorizerLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.statistics.DataHistogram;
import gov.sandia.cognition.statistics.distribution.MapBasedDataHistogram;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;

/**
 * An IVoting categorizer learner that balances each bag across categories.
 *
 * <p>Rather than sampling the correct and incorrect examples from the data as a
 * whole, this learner splits both index lists by each example's output category.
 * It then draws an even share of the correct and incorrect sample budgets from
 * every category known to the ensemble. Each category that has examples
 * contributes at least one sample of each kind per bag.
 *
 * @param <InputType> the type of input the learned categorizer accepts
 * @param <CategoryType> the type of category the learned categorizer outputs
 */
public class CategoryBalancedIVotingLearner<InputType, CategoryType>
extends IVotingCategorizerLearner<InputType, CategoryType> {

    /**
     * Creates a learner with default settings.
     *
     * <p>The defaults are: no base learner, 100 maximum iterations, a sample
     * percentage of 0.1 and a new {@link Random}.
     */
    public CategoryBalancedIVotingLearner() {
        this(null, 100, 0.1, new Random());
    }

    /**
     * Creates a learner with the given base learner and sampling settings.
     *
     * <p>The remaining settings use fixed defaults:
     * <ul>
     *   <li>a proportion of 0.5 incorrect examples in each sample;</li>
     *   <li>out-of-bag-only voting enabled;</li>
     *   <li>a {@code MapBasedDataHistogram.DefaultFactory(2)} as the counter factory.</li>
     * </ul>
     *
     * @param learner the base learner trained on each bag
     * @param maxIterations the maximum number of iterations (passed to the superclass)
     * @param percentToSample the sample percentage (passed to the superclass)
     * @param random the source of randomness used for sampling
     */
    public CategoryBalancedIVotingLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner, int maxIterations, double percentToSample, Random random) {
        this(learner, maxIterations, percentToSample, 0.5, true, new MapBasedDataHistogram.DefaultFactory(2), random);
    }

    /**
     * Creates a learner with every setting given explicitly.
     *
     * <p>All arguments are forwarded unchanged to the superclass constructor.
     *
     * @param learner the base learner trained on each bag
     * @param maxIterations the maximum number of iterations
     * @param percentToSample the sample percentage
     * @param proportionIncorrectInSample the proportion of each sample drawn from
     *     incorrectly categorized examples
     * @param voteOutOfBagOnly whether voting is restricted to out-of-bag examples
     * @param counterFactory the factory used to create category counters
     * @param random the source of randomness used for sampling
     */
    public CategoryBalancedIVotingLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner, int maxIterations, double percentToSample, double proportionIncorrectInSample, boolean voteOutOfBagOnly, Factory<? extends DataHistogram<CategoryType>> counterFactory, Random random) {
        super(learner, maxIterations, percentToSample, proportionIncorrectInSample, voteOutOfBagOnly, counterFactory, random);
    }

    /**
     * Fills the current bag by sampling with replacement, separately for each category.
     *
     * <p>Each category is sampled from its own correct and incorrect example indices.
     * The budgets {@code numCorrectToSample} and {@code numIncorrectToSample} are
     * divided evenly across the ensemble's categories, with at least one sample per
     * category. Empty pools are handled as follows:
     * <ul>
     *   <li>If a category has no incorrect examples, its incorrect share is drawn
     *       from its correct examples.</li>
     *   <li>If it has no correct examples, its correct share is drawn from its
     *       incorrect examples.</li>
     *   <li>A category with no examples at all is skipped.</li>
     * </ul>
     *
     * @param correctIndices indices into {@code dataList} of the examples treated
     *     as correctly categorized
     * @param incorrectIndices indices into {@code dataList} of the examples
     *     treated as incorrectly categorized
     */
    @Override
    protected void createBag(ArrayList<Integer> correctIndices, ArrayList<Integer> incorrectIndices) {
        final Collection<?> categories = this.ensemble.getCategories();

        // Bucket the correct and incorrect indices by their example's category.
        final Map<Object, ArrayList<Integer>> correctPerCategory =
            new LinkedHashMap<Object, ArrayList<Integer>>();
        final Map<Object, ArrayList<Integer>> incorrectPerCategory =
            new LinkedHashMap<Object, ArrayList<Integer>>();
        for (Object category : categories) {
            correctPerCategory.put(category, new ArrayList<Integer>());
            incorrectPerCategory.put(category, new ArrayList<Integer>());
        }
        this.bucketByCategory(correctIndices, correctPerCategory);
        this.bucketByCategory(incorrectIndices, incorrectPerCategory);

        final int categoryCount = categories.size();
        if (categoryCount == 0) {
            // There are no categories to balance over, and dividing the budgets
            // below would otherwise fail.
            return;
        }
        final int correctPerCategorySize = Math.max(1, this.numCorrectToSample / categoryCount);
        final int incorrectPerCategorySize = Math.max(1, this.numIncorrectToSample / categoryCount);

        for (Object category : categories) {
            ArrayList<Integer> categoryCorrect = correctPerCategory.get(category);
            ArrayList<Integer> categoryIncorrect = incorrectPerCategory.get(category);

            // Fall back to this category's other pool when one pool is empty.
            // The check must be per category. Testing the global correctIndices
            // list would hand the sampler an empty correct pool for a category
            // that only has incorrect examples.
            if (categoryIncorrect.isEmpty()) {
                categoryIncorrect = categoryCorrect;
            } else if (categoryCorrect.isEmpty()) {
                categoryCorrect = categoryIncorrect;
            }
            if (categoryCorrect.isEmpty()) {
                // Both pools are empty: there is no example of this category to draw.
                continue;
            }

            CategoryBalancedIVotingLearner.sampleIndicesWithReplacementInto(categoryCorrect, this.dataList, correctPerCategorySize, this.random, this.currentBag, this.dataInBag);
            CategoryBalancedIVotingLearner.sampleIndicesWithReplacementInto(categoryIncorrect, this.dataList, incorrectPerCategorySize, this.random, this.currentBag, this.dataInBag);
        }
    }

    /**
     * Appends each index to the list for its example's output category.
     *
     * <p>The map is pre-populated only with the ensemble's categories. An example
     * whose output is not one of them therefore has no list, and the lookup
     * yields {@code null}.
     *
     * @param indices indices into {@code dataList}
     * @param perCategory the per-category index lists to append to
     */
    private void bucketByCategory(ArrayList<Integer> indices, Map<Object, ArrayList<Integer>> perCategory) {
        for (Integer index : indices) {
            final Object category = ((InputOutputPair<?, ?>) this.dataList.get(index)).getOutput();
            perCategory.get(category).add(index);
        }
    }
}

