/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.trees.ImpurityScore;
import jsat.classifiers.trees.TreeFeatureImportanceInference;
import jsat.classifiers.trees.TreeLearner;
import jsat.classifiers.trees.TreeNodeVisitor;
import jsat.utils.IntList;

public class MDI
implements TreeFeatureImportanceInference {
    private ImpurityScore.ImpurityMeasure im;

    public MDI(ImpurityScore.ImpurityMeasure im) {
        this.im = im;
    }

    public MDI() {
        this(ImpurityScore.ImpurityMeasure.GINI);
    }

    @Override
    public <Type extends DataSet> double[] getImportanceStats(TreeLearner model, DataSet<Type> data) {
        double[] features = new double[data.getNumFeatures()];
        if (!(data instanceof ClassificationDataSet)) {
            throw new RuntimeException("MDI currently only supports classification datasets");
        }
        List<DataPointPair<Integer>> allData = ((ClassificationDataSet)data).getAsDPPList();
        int K = ((ClassificationDataSet)data).getClassSize();
        ImpurityScore score = new ImpurityScore(K, this.im);
        for (int i = 0; i < data.size(); ++i) {
            score.addPoint(data.getWeight(i), ((ClassificationDataSet)data).getDataPointCategory(i));
        }
        this.visit(model.getTreeNodeVisitor(), score, (ClassificationDataSet)data, IntList.range(data.size()), features, score.getSumOfWeights(), K);
        return features;
    }

    /*
     * WARNING - void declaration
     */
    private void visit(TreeNodeVisitor node, ImpurityScore score, ClassificationDataSet data, IntList subset, double[] features, double N, int K) {
        void var18_23;
        if (node == null || node.isLeaf()) {
            return;
        }
        double curScore = score.getScore();
        double curN = score.getSumOfWeights();
        ArrayList<IntList> splitsData = new ArrayList<IntList>(node.childrenCount());
        ArrayList<ImpurityScore> splitScores = new ArrayList<ImpurityScore>(node.childrenCount());
        splitsData.add(subset);
        splitScores.add(score);
        for (int i = 0; i < node.childrenCount() - 1; ++i) {
            splitsData.add(new IntList());
            splitScores.add(new ImpurityScore(K, this.im));
        }
        ListIterator iter = subset.listIterator();
        while (iter.hasNext()) {
            int indx = (Integer)iter.next();
            int tc = data.getDataPointCategory(indx);
            DataPoint dataPoint = data.getDataPoint(indx);
            double w = data.getWeight(indx);
            int path = node.getPath(dataPoint);
            if (path < 0) {
                score.removePoint(w, tc);
                continue;
            }
            if (path <= 0) continue;
            score.removePoint(w, tc);
            ((ImpurityScore)splitScores.get(path)).addPoint(w, tc);
            ((IntList)splitsData.get(path)).add(indx);
            iter.remove();
        }
        double chageInImp = curScore;
        for (ImpurityScore impurityScore : splitScores) {
            chageInImp -= impurityScore.getScore() * (impurityScore.getSumOfWeights() / (1.0E-5 + curN));
        }
        Collection<Integer> featuresUsed = node.featuresUsed();
        Iterator<Integer> iterator = featuresUsed.iterator();
        while (iterator.hasNext()) {
            int feature;
            int n = feature = iterator.next().intValue();
            features[n] = features[n] + chageInImp * curN / N;
        }
        boolean bl = false;
        while (var18_23 < splitScores.size()) {
            this.visit(node.getChild((int)var18_23), (ImpurityScore)splitScores.get((int)var18_23), data, (IntList)splitsData.get((int)var18_23), features, N, K);
            ++var18_23;
        }
    }
}

