/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.trees;

import java.util.ArrayList;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.trees.DecisionStump;
import jsat.classifiers.trees.TreeNodeVisitor;
import jsat.math.SpecialMath;
import jsat.utils.IntList;

/**
 * Static helpers that prune an already-built decision tree, exposed through
 * the {@link TreeNodeVisitor} interface, using a separate data set to decide
 * which branches to keep.
 *
 * <p>The class holds no state and cannot be instantiated.
 */
public class TreePruner {
    /**
     * Added to the total routed weight before normalising branch fractions so
     * that a node at which every point has a missing path yields fractions of
     * zero instead of {@code 0/0 = NaN}.
     */
    private static final double WEIGHT_EPSILON = 1.0E-15;

    /** The {@code alpha} value {@link #prune} passes to the error-based pruner. */
    private static final double DEFAULT_ALPHA = 0.25;

    private TreePruner() {
    }

    /**
     * Prunes the tree rooted at {@code root} in place using the requested
     * method.
     *
     * @param root    the root of the tree to prune
     * @param method  the pruning strategy; {@link PruningMethod#NONE} returns
     *                immediately without touching the tree
     * @param testSet the data used to evaluate which branches to prune
     * @throws RuntimeException if {@code method} is not one of the known
     *                          constants (this includes {@code null})
     */
    public static void prune(TreeNodeVisitor root, PruningMethod method, ClassificationDataSet testSet) {
        if (method == PruningMethod.NONE) {
            return;
        }
        if (method == PruningMethod.REDUCED_ERROR) {
            pruneReduceError(null, -1, root, testSet);
        } else if (method == PruningMethod.ERROR_BASED) {
            pruneErrorBased(null, -1, root, testSet, DEFAULT_ALPHA);
        } else {
            throw new RuntimeException("BUG: please report");
        }
    }

    /**
     * Recursively prunes bottom-up: the children of {@code current} are
     * processed first, then, if {@code current} is (or has become) a leaf, it
     * is compared against its parent on the data that reached it. The path is
     * disabled in the parent when the parent's local classification gets at
     * least as much weight correct as the leaf does.
     *
     * @param parent       the parent of {@code current}, or {@code null} for the root
     * @param pathFollowed the index of the path in {@code parent} that leads to
     *                     {@code current}; unused when {@code parent} is {@code null}
     * @param current      the node being examined; {@code null} is a no-op
     * @param testSet      the portion of the pruning data that reached {@code current}
     * @return the number of nodes that were pruned in this subtree
     */
    private static int pruneReduceError(TreeNodeVisitor parent, int pathFollowed, TreeNodeVisitor current, ClassificationDataSet testSet) {
        if (current == null) {
            return 0;
        }
        int nodesPruned = 0;
        if (!current.isLeaf()) {
            int numSplits = current.childrenCount();
            ArrayList<ClassificationDataSet> splits = new ArrayList<ClassificationDataSet>(numSplits);
            IntList hadMissing = new IntList();
            double[] fracs = new double[numSplits];
            double wSum = 0.0;
            for (int i = 0; i < numSplits; i++) {
                splits.add(testSet.emptyClone());
            }
            // Route every point down its path; a negative path marks a point
            // that could not be routed and is handed to distributMissing below.
            for (int i = 0; i < testSet.size(); i++) {
                double w_i = testSet.getWeight(i);
                int path = current.getPath(testSet.getDataPoint(i));
                if (path >= 0) {
                    splits.get(path).addDataPoint(testSet.getDataPoint(i), testSet.getDataPointCategory(i), w_i);
                    wSum += w_i;
                    fracs[path] += w_i;
                } else {
                    hadMissing.add(i);
                }
            }
            // Turn per-branch weight into the fraction of routed weight.
            for (int i = 0; i < numSplits; i++) {
                fracs[i] /= wSum + WEIGHT_EPSILON;
            }
            if (!hadMissing.isEmpty()) {
                DecisionStump.distributMissing(splits, fracs, testSet, hadMissing);
            }
            // Children are visited from the last index down to the first.
            for (int i = numSplits - 1; i >= 0; i--) {
                nodesPruned += pruneReduceError(current, i, current.getChild(i), splits.get(i));
            }
        }
        // Re-check isLeaf(): pruning the children above may have turned this
        // node into a leaf.
        if (current.isLeaf() && parent != null) {
            double childCorrect = 0.0;
            double parentCorrect = 0.0;
            for (int i = 0; i < testSet.size(); i++) {
                DataPoint dp = testSet.getDataPoint(i);
                int truth = testSet.getDataPointCategory(i);
                if (current.localClassify(dp).mostLikely() == truth) {
                    childCorrect += testSet.getWeight(i);
                }
                if (parent.localClassify(dp).mostLikely() == truth) {
                    parentCorrect += testSet.getWeight(i);
                }
            }
            if (parentCorrect >= childCorrect) {
                parent.disablePath(pathFollowed);
                return nodesPruned + 1;
            }
        }
        return nodesPruned;
    }

    /**
     * Recursively prunes bottom-up by comparing three scores computed with
     * {@link #computeBinomialUpperBound}: keeping the subtree as is, replacing
     * the node with its child that received the most data points, and turning
     * the node into a leaf. The option with the lowest score wins; the subtree
     * is kept on ties.
     *
     * <p>Error counts are accumulated as weighted {@code double} sums so that
     * fractional data point weights are not truncated.
     *
     * @param parent       the parent of {@code current}, or {@code null} for the root
     * @param pathFollowed the index of the path in {@code parent} that leads to
     *                     {@code current}; unused when {@code parent} is {@code null}
     * @param current      the node being examined
     * @param testSet      the portion of the pruning data that reached {@code current}
     * @param alpha        the first argument forwarded to the bound computation
     * @return the score of this subtree after pruning, or {@code 0} when
     *         {@code current} is {@code null} or {@code testSet} is empty
     */
    private static double pruneErrorBased(TreeNodeVisitor parent, int pathFollowed, TreeNodeVisitor current, ClassificationDataSet testSet, double alpha) {
        if (current == null || testSet.isEmpty()) {
            return 0.0;
        }
        if (current.isLeaf()) {
            double errors = 0.0;
            double N = 0.0;
            for (int i = 0; i < testSet.size(); i++) {
                double w_i = testSet.getWeight(i);
                if (current.localClassify(testSet.getDataPoint(i)).mostLikely() != testSet.getDataPointCategory(i)) {
                    errors += w_i;
                }
                N += w_i;
            }
            return computeBinomialUpperBound(N, alpha, errors);
        }

        int numSplits = current.childrenCount();
        ArrayList<ClassificationDataSet> splitSet = new ArrayList<ClassificationDataSet>(numSplits);
        IntList hadMissing = new IntList();
        for (int i = 0; i < numSplits; i++) {
            splitSet.add(testSet.emptyClone());
        }

        double localErrors = 0.0;
        double subTreeScore = 0.0;
        double N = 0.0;
        double N_missing = 0.0;
        double[] fracs = new double[splitSet.size()];
        // One pass: count the errors this node would make as a leaf and route
        // each point to the child data set it belongs to.
        for (int i = 0; i < testSet.size(); i++) {
            DataPoint dp = testSet.getDataPoint(i);
            int y_i = testSet.getDataPointCategory(i);
            double w_i = testSet.getWeight(i);
            if (current.localClassify(dp).mostLikely() != y_i) {
                localErrors += w_i;
            }
            int path = current.getPath(dp);
            if (path >= 0) {
                N += w_i;
                splitSet.get(path).addDataPoint(dp, y_i, w_i);
                fracs[path] += w_i;
            } else {
                hadMissing.add(i);
                N_missing += w_i;
            }
        }
        // Same epsilon guard as pruneReduceError, so N == 0 does not yield NaN.
        for (int i = 0; i < fracs.length; i++) {
            fracs[i] /= N + WEIGHT_EPSILON;
        }
        if (!hadMissing.isEmpty()) {
            DecisionStump.distributMissing(splitSet, fracs, testSet, hadMissing);
        }

        // Recurse into the enabled children and remember which of them
        // received the largest number of data points.
        int maxChildCount = 0;
        int maxChild = -1;
        for (int path = 0; path < splitSet.size(); path++) {
            if (current.isPathDisabled(path)) {
                continue;
            }
            ClassificationDataSet split = splitSet.get(path);
            subTreeScore += pruneErrorBased(current, path, current.getChild(path), split, alpha);
            if (split.size() > maxChildCount) {
                maxChildCount = split.size();
                maxChild = path;
            }
        }

        double prunedTreeScore = computeBinomialUpperBound(N + N_missing, alpha, localErrors);

        double maxChildTreeScore;
        if (maxChild == -1) {
            // No enabled child received data, so replacement is never chosen.
            maxChildTreeScore = Double.POSITIVE_INFINITY;
        } else {
            // Score the largest child as if it handled all data at this node.
            TreeNodeVisitor maxChildNode = current.getChild(maxChild);
            double otherE = 0.0;
            for (int path = 0; path < splitSet.size(); path++) {
                ClassificationDataSet split = splitSet.get(path);
                for (int i = 0; i < split.size(); i++) {
                    if (maxChildNode.classify(split.getDataPoint(i)).mostLikely() != split.getDataPointCategory(i)) {
                        otherE += split.getWeight(i);
                    }
                }
            }
            maxChildTreeScore = computeBinomialUpperBound(N + N_missing, alpha, otherE);
        }

        if (maxChildTreeScore < prunedTreeScore && maxChildTreeScore < subTreeScore && parent != null) {
            try {
                parent.setPath(pathFollowed, current.getChild(maxChild));
                return maxChildTreeScore;
            } catch (UnsupportedOperationException ignored) {
                // The parent cannot have its child replaced; fall through and
                // consider the remaining options (collapse to leaf or keep).
            }
        }
        if (prunedTreeScore < subTreeScore) {
            // Collapse this node into a leaf by disabling every path.
            for (int i = 0; i < current.childrenCount(); i++) {
                current.disablePath(i);
            }
            return prunedTreeScore;
        }
        return subTreeScore;
    }

    /**
     * Computes {@code N * (1 - invBetaIncReg(alpha, N - errors + 1e-9, errors + 1))}.
     * Going by the method name this is an upper bound on the error count under
     * a binomial model; the exact semantics depend on
     * {@link SpecialMath#invBetaIncReg}, which is not visible here.
     *
     * @param N      the total weight of the data points considered
     * @param alpha  the first argument passed to {@code invBetaIncReg}
     * @param errors the (possibly fractional) weight of misclassified points
     * @return the computed score
     */
    private static double computeBinomialUpperBound(double N, double alpha, double errors) {
        // The 1e-9 offset keeps the second argument strictly positive when
        // errors == N.
        return N * (1.0 - SpecialMath.invBetaIncReg(alpha, N - errors + 1.0E-9, errors + 1.0));
    }

    /** The pruning strategies accepted by {@link TreePruner#prune}. */
    public enum PruningMethod {
        /** Leave the tree untouched. */
        NONE,
        /**
         * Bottom-up pruning that removes a leaf when its parent classifies the
         * data reaching that leaf at least as accurately.
         */
        REDUCED_ERROR,
        /**
         * Bottom-up pruning that compares upper-bound error scores for keeping
         * a subtree, replacing it with its largest child, or collapsing it
         * into a leaf.
         */
        ERROR_BASED;
    }
}

