/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.trees;

import java.util.Random;
import java.util.Set;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.trees.DecisionTree;
import jsat.classifiers.trees.TreePruner;
import jsat.regression.RegressionDataSet;
import jsat.utils.ModifiableCountDownLatch;
import jsat.utils.random.RandomUtil;

/**
 * A {@link DecisionTree} that restricts the candidate split features at every
 * node to a randomly drawn subset of the data set's features. Before each node
 * is built, the {@code options} set passed to the parent implementation is
 * replaced with {@link #getRandomFeatureCount()} distinct, randomly chosen
 * feature indices (or all features, if fewer are available).
 */
public class RandomDecisionTree
extends DecisionTree {
    private static final long serialVersionUID = -809244056947507494L;
    /** Number of features to randomly consider at each node; always at least 1. */
    private int numFeatures;

    /**
     * Creates a random decision tree that considers a single random feature
     * at each node, using the parent's default settings.
     */
    public RandomDecisionTree() {
        this(1);
    }

    /**
     * Creates a random decision tree using the parent's default settings.
     *
     * @param numFeatures the number of random features to consider at each node
     * @throws IllegalArgumentException if {@code numFeatures} is less than 1
     */
    public RandomDecisionTree(int numFeatures) {
        this.setRandomFeatureCount(numFeatures);
    }

    /**
     * Creates a random decision tree with explicit tree-growth parameters.
     *
     * @param numFeatures    the number of random features to consider at each node
     * @param maxDepth       passed through to the {@link DecisionTree} constructor
     * @param minSamples     passed through to the {@link DecisionTree} constructor
     * @param pruningMethod  passed through to the {@link DecisionTree} constructor
     * @param testProportion passed through to the {@link DecisionTree} constructor
     * @throws IllegalArgumentException if {@code numFeatures} is less than 1
     */
    public RandomDecisionTree(int numFeatures, int maxDepth, int minSamples, TreePruner.PruningMethod pruningMethod, double testProportion) {
        super(maxDepth, minSamples, pruningMethod, testProportion);
        this.setRandomFeatureCount(numFeatures);
    }

    /**
     * Copy constructor.
     *
     * @param toCopy the tree to copy
     */
    public RandomDecisionTree(RandomDecisionTree toCopy) {
        super(toCopy);
        this.numFeatures = toCopy.numFeatures;
    }

    /**
     * Sets how many features are randomly selected as split candidates at
     * each node. If the value exceeds the number of features in the training
     * data, all features are used.
     *
     * @param numFeatures the number of random features, must be at least 1
     * @throws IllegalArgumentException if {@code numFeatures} is less than 1
     */
    public void setRandomFeatureCount(int numFeatures) {
        if (numFeatures < 1) {
            throw new IllegalArgumentException("Number of features must be positive, not " + numFeatures);
        }
        this.numFeatures = numFeatures;
    }

    /**
     * @return the number of features randomly selected at each node
     */
    public int getRandomFeatureCount() {
        return this.numFeatures;
    }

    /**
     * Replaces {@code options} with a random feature subset and then delegates
     * to {@link DecisionTree#makeNodeC}. An empty data set counts down the
     * latch and yields {@code null}.
     */
    @Override
    protected DecisionTree.Node makeNodeC(ClassificationDataSet dataPoints, Set<Integer> options, int depth, boolean parallel, ModifiableCountDownLatch mcdl) {
        if (dataPoints.isEmpty()) {
            mcdl.countDown();
            return null;
        }
        int featureCount = dataPoints.getNumFeatures();
        this.fillWithRandomFeatures(options, featureCount);
        return super.makeNodeC(dataPoints, options, depth, parallel, mcdl);
    }

    /**
     * Replaces {@code options} with a random feature subset and then delegates
     * to {@link DecisionTree#makeNodeR}. An empty data set counts down the
     * latch and yields {@code null}.
     */
    @Override
    protected DecisionTree.Node makeNodeR(RegressionDataSet dataPoints, Set<Integer> options, int depth, boolean parallel, ModifiableCountDownLatch mcdl) {
        if (dataPoints.isEmpty()) {
            mcdl.countDown();
            return null;
        }
        int featureCount = dataPoints.getNumFeatures();
        this.fillWithRandomFeatures(options, featureCount);
        return super.makeNodeR(dataPoints, options, depth, parallel, mcdl);
    }

    /**
     * Clears {@code options} and fills it with distinct feature indices drawn
     * uniformly from {@code [0, featureCount)}.
     * <p>
     * At most {@code featureCount} indices exist, so the sample size is capped
     * at {@code min(numFeatures, featureCount)}. Without that cap the sampling
     * loop could never terminate when more features are requested than the
     * data has. With no features at all, {@code Random.nextInt(0)} would throw.
     *
     * @param options      the set to overwrite with the selected feature indices
     * @param featureCount the number of features available in the data
     */
    private void fillWithRandomFeatures(Set<Integer> options, int featureCount) {
        options.clear();
        int target = Math.min(this.numFeatures, featureCount);
        if (target >= featureCount) {
            // Every feature is a candidate, so add them all directly rather than
            // rejection-sampling until each index has happened to be drawn.
            for (int i = 0; i < featureCount; i++) {
                options.add(i);
            }
            return;
        }
        Random rand = RandomUtil.getRandom();
        // Rejection sampling: duplicates are absorbed by the Set.
        while (options.size() < target) {
            options.add(rand.nextInt(featureCount));
        }
    }

    @Override
    public RandomDecisionTree clone() {
        return new RandomDecisionTree(this);
    }
}

