/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.Stack;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.trees.ImpurityScore;
import jsat.classifiers.trees.TreeLearner;
import jsat.classifiers.trees.TreeNodeVisitor;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.random.RandomUtil;

public class ExtraTree
implements Classifier,
Regressor,
TreeLearner,
Parameterized {
    private static final long serialVersionUID = 7433728970041876327L;
    private int stopSize;
    private int selectionCount;
    protected CategoricalData predicting;
    private boolean binaryCategoricalSplitting = true;
    private int numNumericFeatures;
    private ImpurityScore.ImpurityMeasure impMeasure = ImpurityScore.ImpurityMeasure.NMI;
    private TreeNodeVisitor root;

    public ExtraTree() {
        this(Integer.MAX_VALUE, 5);
    }

    public ExtraTree(int selectionCount, int stopSize) {
        this.stopSize = stopSize;
        this.selectionCount = selectionCount;
        this.impMeasure = ImpurityScore.ImpurityMeasure.NMI;
    }

    public ExtraTree(ExtraTree toCopy) {
        this.stopSize = toCopy.stopSize;
        this.selectionCount = toCopy.selectionCount;
        if (toCopy.predicting != null) {
            this.predicting = toCopy.predicting;
        }
        this.numNumericFeatures = toCopy.numNumericFeatures;
        this.binaryCategoricalSplitting = toCopy.binaryCategoricalSplitting;
        this.impMeasure = toCopy.impMeasure;
        if (toCopy.root != null) {
            this.root = toCopy.root.clone();
        }
    }

    public void setImpurityMeasure(ImpurityScore.ImpurityMeasure impurityMeasure) {
        this.impMeasure = impurityMeasure;
    }

    public ImpurityScore.ImpurityMeasure getImpurityMeasure() {
        return this.impMeasure;
    }

    public void setStopSize(int stopSize) {
        if (stopSize <= 0) {
            throw new ArithmeticException("The stopping size must be a positive value");
        }
        this.stopSize = stopSize;
    }

    public int getStopSize() {
        return this.stopSize;
    }

    public void setSelectionCount(int selectionCount) {
        this.selectionCount = selectionCount;
    }

    public int getSelectionCount() {
        return this.selectionCount;
    }

    public void setBinaryCategoricalSplitting(boolean binaryCategoricalSplitting) {
        this.binaryCategoricalSplitting = binaryCategoricalSplitting;
    }

    public boolean isBinaryCategoricalSplitting() {
        return this.binaryCategoricalSplitting;
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        return this.root.classify(data);
    }

    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        Random rand = RandomUtil.getRandom();
        IntList features = new IntList(dataSet.getNumFeatures());
        ListUtils.addRange(features, 0, dataSet.getNumFeatures(), 1);
        this.predicting = dataSet.getPredicting();
        ImpurityScore score = new ImpurityScore(this.predicting.getNumOfCategories(), this.impMeasure);
        for (int i = 0; i < dataSet.size(); ++i) {
            score.addPoint(dataSet.getWeight(i), dataSet.getDataPointCategory(i));
        }
        this.numNumericFeatures = dataSet.getNumNumericalVars();
        this.root = this.trainC(score, dataSet, features, dataSet.getCategories(), rand);
    }

    private TreeNodeVisitor trainC(ImpurityScore setScore, ClassificationDataSet subSet, List<Integer> features, CategoricalData[] catInfo, Random rand) {
        NodeC toReturn;
        if (subSet.size() < this.stopSize || setScore.getScore() == 0.0) {
            if (subSet.isEmpty()) {
                return null;
            }
            return new NodeC(setScore.getResults());
        }
        double bestGain = Double.NEGATIVE_INFINITY;
        double bestThreshold = Double.NaN;
        int bestAttribute = -1;
        ImpurityScore[] bestScores = null;
        ArrayList<ClassificationDataSet> bestSplit = null;
        IntSet bestLeftSide = null;
        Collections.shuffle(features);
        int goTo = Math.min(this.selectionCount, features.size());
        for (int i = 0; i < goTo; ++i) {
            double gain;
            ArrayList<ClassificationDataSet> aSplit;
            ImpurityScore[] scores;
            double threshold = Double.NaN;
            IntSet leftSide = null;
            int a = features.get(i);
            if (a < catInfo.length) {
                int vals = catInfo[a].getNumOfCategories();
                if (this.binaryCategoricalSplitting || vals == 2) {
                    scores = this.createScores(2);
                    IntSet catsValsInUse = new IntSet(vals * 2);
                    for (int j = 0; j < subSet.size(); ++j) {
                        catsValsInUse.add(Integer.valueOf(subSet.getDataPoint(j).getCategoricalValue(a)));
                    }
                    if (catsValsInUse.size() == 1) {
                        return new NodeC(setScore.getResults());
                    }
                    leftSide = new IntSet(vals);
                    int toUse = rand.nextInt(catsValsInUse.size() - 1) + 1;
                    ListUtils.randomSample(catsValsInUse, leftSide, toUse, rand);
                    aSplit = new ArrayList(2);
                    aSplit.add(subSet.emptyClone());
                    aSplit.add(subSet.emptyClone());
                    for (int j = 0; j < subSet.size(); ++j) {
                        int dest = leftSide.contains((Object)subSet.getDataPoint(j).getCategoricalValue(a)) ? 0 : 1;
                        scores[dest].addPoint(subSet.getWeight(j), subSet.getDataPointCategory(j));
                        ((ClassificationDataSet)aSplit.get(dest)).addDataPoint(subSet.getDataPoint(j), subSet.getDataPointCategory(j), subSet.getWeight(j));
                    }
                } else {
                    scores = this.createScores(vals);
                    aSplit = new ArrayList<ClassificationDataSet>(vals);
                    for (int z = 0; z < vals; ++z) {
                        aSplit.add(subSet.emptyClone());
                    }
                    for (int j = 0; j < subSet.size(); ++j) {
                        DataPoint dp = subSet.getDataPoint(j);
                        int y_j = subSet.getDataPointCategory(j);
                        double w_j = subSet.getWeight(j);
                        scores[dp.getCategoricalValue(a)].addPoint(w_j, y_j);
                        ((ClassificationDataSet)aSplit.get(dp.getCategoricalValue(a))).addDataPoint(dp, y_j, w_j);
                    }
                }
            } else {
                double val;
                int j;
                int numerA = a - catInfo.length;
                double min = Double.POSITIVE_INFINITY;
                double max = Double.NEGATIVE_INFINITY;
                for (j = 0; j < subSet.size(); ++j) {
                    val = subSet.getDataPoint(j).getNumericalValues().get(numerA);
                    min = Math.min(min, val);
                    max = Math.max(max, val);
                }
                threshold = rand.nextDouble() * (max - min) + min;
                scores = this.createScores(2);
                aSplit = new ArrayList(2);
                aSplit.add(subSet.emptyClone());
                aSplit.add(subSet.emptyClone());
                for (j = 0; j < subSet.size(); ++j) {
                    val = subSet.getDataPoint(j).getNumericalValues().get(numerA);
                    double w_j = subSet.getWeight(j);
                    int y_j = subSet.getDataPointCategory(j);
                    int toAddTo = val <= threshold ? 0 : 1;
                    ((ClassificationDataSet)aSplit.get(toAddTo)).addDataPoint(subSet.getDataPoint(j), y_j, w_j);
                    scores[toAddTo].addPoint(w_j, y_j);
                }
            }
            if (!((gain = ImpurityScore.gain(setScore, scores)) > bestGain)) continue;
            bestGain = gain;
            bestAttribute = a;
            bestThreshold = threshold;
            bestScores = scores;
            bestSplit = aSplit;
            bestLeftSide = leftSide;
        }
        if (bestAttribute < 0) {
            return null;
        }
        if (bestAttribute < catInfo.length) {
            if (bestSplit.size() == 2) {
                toReturn = new NodeCCat(bestAttribute, bestLeftSide, setScore.getResults());
            } else {
                toReturn = new NodeCCat(goTo, bestSplit.size(), setScore.getResults());
                features.remove(new Integer(bestAttribute));
            }
        } else {
            toReturn = new NodeCNum(bestAttribute - catInfo.length, bestThreshold, setScore.getResults());
        }
        for (int i = 0; i < toReturn.children.length; ++i) {
            toReturn.children[i] = this.trainC((ImpurityScore)bestScores[i], (ClassificationDataSet)bestSplit.get(i), features, catInfo, rand);
        }
        return toReturn;
    }

    private TreeNodeVisitor train(OnLineStatistics setScore, RegressionDataSet subSet, List<Integer> features, CategoricalData[] catInfo, Random rand) {
        if (subSet.size() < this.stopSize || setScore.getVarance() <= 0.0 || Double.isNaN(setScore.getVarance())) {
            return new NodeR(setScore.getMean());
        }
        double bestGain = Double.NEGATIVE_INFINITY;
        double bestThreshold = Double.NaN;
        int bestAttribute = -1;
        OnLineStatistics[] bestScores = null;
        ArrayList<RegressionDataSet> bestSplit = null;
        IntSet bestLeftSide = null;
        Collections.shuffle(features);
        int goTo = Math.min(this.selectionCount, features.size());
        for (int i = 0; i < goTo; ++i) {
            ArrayList<RegressionDataSet> aSplit;
            OnLineStatistics[] stats;
            double threshold = Double.NaN;
            IntSet leftSide = null;
            int a = features.get(i);
            if (a < catInfo.length) {
                int vals = catInfo[a].getNumOfCategories();
                if (this.binaryCategoricalSplitting || vals == 2) {
                    stats = this.createStats(2);
                    IntSet catsValsInUse = new IntSet(vals * 2);
                    for (int j = 0; j < subSet.size(); ++j) {
                        catsValsInUse.add(Integer.valueOf(subSet.getDataPoint(j).getCategoricalValue(a)));
                    }
                    if (catsValsInUse.size() == 1) {
                        return new NodeR(setScore.getMean());
                    }
                    leftSide = new IntSet(vals);
                    int toUse = rand.nextInt(catsValsInUse.size() - 1) + 1;
                    ListUtils.randomSample(catsValsInUse, leftSide, toUse, rand);
                    aSplit = new ArrayList(2);
                    aSplit.add(subSet.emptyClone());
                    aSplit.add(subSet.emptyClone());
                    for (int j = 0; j < subSet.size(); ++j) {
                        DataPoint dp = subSet.getDataPoint(j);
                        double w_j = subSet.getWeight(j);
                        double y_j = subSet.getTargetValue(j);
                        int dest = leftSide.contains((Object)dp.getCategoricalValue(a)) ? 0 : 1;
                        stats[dest].add(y_j, w_j);
                        ((RegressionDataSet)aSplit.get(dest)).addDataPoint(dp, y_j, w_j);
                    }
                } else {
                    stats = this.createStats(vals);
                    aSplit = new ArrayList<RegressionDataSet>(vals);
                    for (int z = 0; z < vals; ++z) {
                        aSplit.add(subSet.emptyClone());
                    }
                    for (int j = 0; j < subSet.size(); ++j) {
                        DataPoint dp = subSet.getDataPoint(j);
                        double w_j = subSet.getWeight(j);
                        double y_j = subSet.getTargetValue(j);
                        stats[dp.getCategoricalValue(a)].add(y_j, w_j);
                        ((RegressionDataSet)aSplit.get(dp.getCategoricalValue(a))).addDataPoint(dp, y_j, w_j);
                    }
                }
            } else {
                DataPoint dp;
                int j;
                int numerA = a - catInfo.length;
                double min = Double.POSITIVE_INFINITY;
                double max = Double.NEGATIVE_INFINITY;
                for (j = 0; j < subSet.size(); ++j) {
                    dp = subSet.getDataPoint(j);
                    double val = dp.getNumericalValues().get(numerA);
                    min = Math.min(min, val);
                    max = Math.max(max, val);
                }
                threshold = rand.nextDouble() * (max - min) + min;
                stats = this.createStats(2);
                aSplit = new ArrayList(2);
                aSplit.add(subSet.emptyClone());
                aSplit.add(subSet.emptyClone());
                for (j = 0; j < subSet.size(); ++j) {
                    dp = subSet.getDataPoint(j);
                    double w_j = subSet.getWeight(j);
                    double y_j = subSet.getTargetValue(j);
                    double val = dp.getNumericalValues().get(numerA);
                    int toAddTo = val <= threshold ? 0 : 1;
                    ((RegressionDataSet)aSplit.get(toAddTo)).addDataPoint(dp, y_j, w_j);
                    stats[toAddTo].add(y_j, w_j);
                }
            }
            double gain = 1.0;
            double varNorm = setScore.getVarance();
            double varSum = setScore.getSumOfWeights();
            for (OnLineStatistics stat : stats) {
                gain -= stat.getSumOfWeights() / varSum * (stat.getVarance() / varNorm);
            }
            if (!(gain > bestGain)) continue;
            bestGain = gain;
            bestAttribute = a;
            bestThreshold = threshold;
            bestScores = stats;
            bestSplit = aSplit;
            bestLeftSide = leftSide;
        }
        if (bestAttribute >= 0) {
            NodeR toReturn;
            if (bestAttribute < catInfo.length) {
                if (bestSplit.size() == 2) {
                    toReturn = new NodeRCat(bestAttribute, bestLeftSide, setScore.getMean());
                } else {
                    toReturn = new NodeRCat(goTo, bestSplit.size(), setScore.getMean());
                    features.remove(new Integer(bestAttribute));
                }
            } else {
                toReturn = new NodeRNum(bestAttribute - catInfo.length, bestThreshold, setScore.getMean());
            }
            for (int i = 0; i < toReturn.children.length; ++i) {
                toReturn.children[i] = this.train((OnLineStatistics)bestScores[i], (RegressionDataSet)bestSplit.get(i), features, catInfo, rand);
            }
            return toReturn;
        }
        return new NodeR(setScore.getMean());
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    @Override
    public ExtraTree clone() {
        return new ExtraTree(this);
    }

    @Override
    public TreeNodeVisitor getTreeNodeVisitor() {
        return this.root;
    }

    private static <T> void fillList(int listsToAdd, Stack<List<T>> reusableLists, List<List<T>> aSplit) {
        for (int j = 0; j < listsToAdd; ++j) {
            if (reusableLists.isEmpty()) {
                aSplit.add(new ArrayList());
                continue;
            }
            aSplit.add(reusableLists.pop());
        }
    }

    private ImpurityScore[] createScores(int count) {
        ImpurityScore[] scores = new ImpurityScore[count];
        for (int j = 0; j < scores.length; ++j) {
            scores[j] = new ImpurityScore(this.predicting.getNumOfCategories(), this.impMeasure);
        }
        return scores;
    }

    @Override
    public double regress(DataPoint data) {
        return this.root.regress(data);
    }

    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    @Override
    public void train(RegressionDataSet dataSet) {
        Random rand = RandomUtil.getRandom();
        IntList features = new IntList(dataSet.getNumFeatures());
        ListUtils.addRange(features, 0, dataSet.getNumFeatures(), 1);
        OnLineStatistics score = new OnLineStatistics();
        for (int j = 0; j < dataSet.size(); ++j) {
            double w_j = dataSet.getWeight(j);
            double y_j = dataSet.getTargetValue(j);
            score.add(y_j, w_j);
        }
        this.numNumericFeatures = dataSet.getNumNumericalVars();
        this.root = this.train(score, dataSet, features, dataSet.getCategories(), rand);
    }

    private OnLineStatistics[] createStats(int count) {
        OnLineStatistics[] stats = new OnLineStatistics[count];
        for (int i = 0; i < stats.length; ++i) {
            stats[i] = new OnLineStatistics();
        }
        return stats;
    }

    private class NodeRCat
    extends NodeR {
        private static final long serialVersionUID = 5868393594474661054L;
        private int catAtt;
        private int[] leftBranch;

        public NodeRCat(int catAtt, int children, double result) {
            super(result, children);
            this.catAtt = catAtt;
            this.leftBranch = null;
        }

        public NodeRCat(int catAtt, Set<Integer> left, double result) {
            super(result, 2);
            this.catAtt = catAtt;
            this.leftBranch = new int[left.size()];
            int pos = 0;
            for (int i : left) {
                this.leftBranch[pos++] = i;
            }
            Arrays.sort(this.leftBranch);
        }

        public NodeRCat(NodeRCat toClone) {
            super(toClone);
            this.catAtt = toClone.catAtt;
            if (toClone.leftBranch != null) {
                this.leftBranch = Arrays.copyOf(toClone.leftBranch, toClone.leftBranch.length);
            }
        }

        @Override
        public int getPath(DataPoint dp) {
            int[] catVals = dp.getCategoricalValues();
            if (this.leftBranch == null) {
                return catVals[this.catAtt];
            }
            if (Arrays.binarySearch(this.leftBranch, catVals[this.catAtt]) < 0) {
                return 1;
            }
            return 0;
        }

        @Override
        public Collection<Integer> featuresUsed() {
            IntList used = new IntList(1);
            used.add(this.catAtt + ExtraTree.this.numNumericFeatures);
            return used;
        }

        @Override
        public TreeNodeVisitor clone() {
            return new NodeRCat(this);
        }
    }

    private static class NodeRNum
    extends NodeR {
        private static final long serialVersionUID = -6775472771777960211L;
        private int numerAtt;
        private double threshold;

        public NodeRNum(int numerAtt, double threshold, double result) {
            super(result, 2);
            this.numerAtt = numerAtt;
            this.threshold = threshold;
        }

        public NodeRNum(NodeRNum toClone) {
            super(toClone);
            this.numerAtt = toClone.numerAtt;
            this.threshold = toClone.threshold;
        }

        @Override
        public int getPath(DataPoint dp) {
            double val = dp.getNumericalValues().get(this.numerAtt);
            if (val <= this.threshold) {
                return 0;
            }
            return 1;
        }

        @Override
        public TreeNodeVisitor clone() {
            return new NodeRNum(this);
        }

        @Override
        public Collection<Integer> featuresUsed() {
            IntList used = new IntList(1);
            used.add(this.numerAtt);
            return used;
        }
    }

    private static class NodeR
    extends NodeBase {
        private static final long serialVersionUID = -2461046505444129890L;
        private double result;

        public NodeR(double result) {
            this.result = result;
        }

        public NodeR(double result, int children) {
            super(children);
            this.result = result;
        }

        public NodeR(NodeR toClone) {
            super(toClone);
            this.result = toClone.result;
        }

        @Override
        public double localRegress(DataPoint dp) {
            return this.result;
        }

        @Override
        public int getPath(DataPoint dp) {
            return -1;
        }

        @Override
        public TreeNodeVisitor clone() {
            return new NodeR(this);
        }

        @Override
        public Collection<Integer> featuresUsed() {
            return Collections.EMPTY_SET;
        }
    }

    private static abstract class NodeBase
    extends TreeNodeVisitor {
        private static final long serialVersionUID = 6783491817922690901L;
        protected TreeNodeVisitor[] children;

        public NodeBase() {
        }

        public NodeBase(int children) {
            this.children = new TreeNodeVisitor[children];
        }

        public NodeBase(NodeBase toClone) {
            if (toClone.children != null) {
                this.children = new TreeNodeVisitor[toClone.children.length];
                for (int i = 0; i < toClone.children.length; ++i) {
                    if (toClone.children[i] == null) continue;
                    this.children[i] = toClone.children[i].clone();
                }
            }
        }

        @Override
        public int childrenCount() {
            return this.children.length;
        }

        @Override
        public boolean isLeaf() {
            if (this.children == null) {
                return true;
            }
            for (int i = 0; i < this.children.length; ++i) {
                if (this.children[i] == null) continue;
                return false;
            }
            return true;
        }

        @Override
        public TreeNodeVisitor getChild(int child) {
            if (child < 0 || child > this.childrenCount()) {
                return null;
            }
            return this.children[child];
        }

        @Override
        public void disablePath(int child) {
            if (!this.isLeaf()) {
                this.children[child] = null;
            }
        }

        @Override
        public boolean isPathDisabled(int child) {
            if (this.isLeaf()) {
                return true;
            }
            return this.children[child] == null;
        }
    }

    private static class NodeC
    extends NodeBase {
        private static final long serialVersionUID = -3977497656918695759L;
        private CategoricalResults crResult;

        public NodeC(CategoricalResults crResult) {
            this.crResult = crResult;
            this.children = null;
        }

        public NodeC(CategoricalResults crResult, int children) {
            super(children);
            this.crResult = crResult;
        }

        public NodeC(NodeC toClone) {
            super(toClone);
            this.crResult = toClone.crResult.clone();
        }

        @Override
        public CategoricalResults localClassify(DataPoint dp) {
            return this.crResult;
        }

        @Override
        public int getPath(DataPoint dp) {
            return -1;
        }

        @Override
        public TreeNodeVisitor clone() {
            return new NodeC(this);
        }

        @Override
        public Collection<Integer> featuresUsed() {
            return Collections.EMPTY_SET;
        }
    }

    private static class NodeCNum
    extends NodeC {
        private static final long serialVersionUID = 3967180517059509869L;
        private int numerAtt;
        private double threshold;

        public NodeCNum(int numerAtt, double threshold, CategoricalResults crResult) {
            super(crResult, 2);
            this.numerAtt = numerAtt;
            this.threshold = threshold;
        }

        public NodeCNum(NodeCNum toClone) {
            super(toClone);
            this.numerAtt = toClone.numerAtt;
            this.threshold = toClone.threshold;
        }

        @Override
        public int getPath(DataPoint dp) {
            double val = dp.getNumericalValues().get(this.numerAtt);
            if (val <= this.threshold) {
                return 0;
            }
            return 1;
        }

        @Override
        public TreeNodeVisitor clone() {
            return new NodeCNum(this);
        }

        @Override
        public Collection<Integer> featuresUsed() {
            IntList used = new IntList(1);
            used.add(this.numerAtt);
            return used;
        }
    }

    private class NodeCCat
    extends NodeC {
        private static final long serialVersionUID = 7413428280703235600L;
        private int catAtt;
        private int[] leftBranch;

        public NodeCCat(int catAtt, int children, CategoricalResults crResult) {
            super(crResult, children);
            this.catAtt = catAtt;
            this.leftBranch = null;
        }

        public NodeCCat(int catAtt, Set<Integer> left, CategoricalResults crResult) {
            super(crResult, 2);
            this.catAtt = catAtt;
            this.leftBranch = new int[left.size()];
            int pos = 0;
            for (int i : left) {
                this.leftBranch[pos++] = i;
            }
            Arrays.sort(this.leftBranch);
        }

        public NodeCCat(NodeCCat toClone) {
            super(toClone);
            this.catAtt = toClone.catAtt;
            if (toClone.leftBranch != null) {
                this.leftBranch = Arrays.copyOf(toClone.leftBranch, toClone.leftBranch.length);
            }
        }

        @Override
        public int getPath(DataPoint dp) {
            int[] catVals = dp.getCategoricalValues();
            if (this.leftBranch == null) {
                return catVals[this.catAtt];
            }
            if (Arrays.binarySearch(this.leftBranch, catVals[this.catAtt]) < 0) {
                return 1;
            }
            return 0;
        }

        @Override
        public TreeNodeVisitor clone() {
            return new NodeCCat(this);
        }

        @Override
        public Collection<Integer> featuresUsed() {
            IntList used = new IntList(1);
            used.add(this.catAtt + ExtraTree.this.numNumericFeatures);
            return used;
        }
    }
}

