/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.trees;

import java.util.Arrays;
import jsat.classifiers.CategoricalResults;

/**
 * Accumulates weighted per-class totals for a set of data points and scores
 * their impurity under a chosen {@link ImpurityMeasure}. The static
 * {@code gain} methods compare a parent set against candidate child splits.
 *
 * <p>Instances are mutable (see {@link #addPoint} and {@link #removePoint})
 * and perform no internal synchronization.
 */
public class ImpurityScore
implements Cloneable {
    /** ln(2), used to convert natural-log entropy into bits. */
    private static final double LN_2 = Math.log(2.0);

    /** Total weight of all points currently added. */
    private double sumOfWeights;
    /** Total weight per class; the array index is the class label. */
    private final double[] counts;
    /** Measure used by {@link #getScore()} and by the {@code gain} methods. */
    private final ImpurityMeasure impurityMeasure;

    /**
     * Creates an empty score with no points added.
     *
     * @param classCount number of class labels; valid targets are
     *                   {@code 0 .. classCount - 1}
     * @param impurityMeasure the measure this score computes
     */
    public ImpurityScore(int classCount, ImpurityMeasure impurityMeasure) {
        this.sumOfWeights = 0.0;
        this.counts = new double[classCount];
        this.impurityMeasure = impurityMeasure;
    }

    /** Copy constructor backing {@link #clone()}; the count array is copied. */
    private ImpurityScore(ImpurityScore toClone) {
        this.sumOfWeights = toClone.sumOfWeights;
        this.counts = Arrays.copyOf(toClone.counts, toClone.counts.length);
        this.impurityMeasure = toClone.impurityMeasure;
    }

    /**
     * Subtracts a point's weight from its class total and from the overall total.
     *
     * @param weight weight of the point
     * @param targetClass class label of the point
     * @throws ArrayIndexOutOfBoundsException if {@code targetClass} is out of range
     */
    public void removePoint(double weight, int targetClass) {
        this.counts[targetClass] -= weight;
        this.sumOfWeights -= weight;
    }

    /**
     * Adds a point's weight to its class total and to the overall total.
     *
     * @param weight weight of the point
     * @param targetClass class label of the point
     * @throws ArrayIndexOutOfBoundsException if {@code targetClass} is out of range
     */
    public void addPoint(double weight, int targetClass) {
        this.counts[targetClass] += weight;
        this.sumOfWeights += weight;
    }

    /**
     * Computes the impurity of the points currently added.
     * <ul>
     *   <li>INFORMATION_GAIN, INFORMATION_GAIN_RATIO, NMI: Shannon entropy in
     *       bits, {@code -sum p*log2(p)}</li>
     *   <li>GINI: {@code 1 - sum p^2}</li>
     *   <li>CLASSIFICATION_ERROR: {@code 1 - max p}</li>
     * </ul>
     * where {@code p} is each class's share of the total weight.
     *
     * @return the non-negative impurity; 0.0 if the total weight is not
     *         positive or the measure is not one of the above
     */
    public double getScore() {
        if (this.sumOfWeights <= 0.0) {
            return 0.0;
        }
        double score = 0.0;
        if (this.impurityMeasure == ImpurityMeasure.INFORMATION_GAIN_RATIO
                || this.impurityMeasure == ImpurityMeasure.INFORMATION_GAIN
                || this.impurityMeasure == ImpurityMeasure.NMI) {
            for (double count : this.counts) {
                double p = count / this.sumOfWeights;
                // 0 * log(0) is taken as 0, so empty classes are skipped
                if (p > 0.0) {
                    score += p * Math.log(p) / LN_2;
                }
            }
        } else if (this.impurityMeasure == ImpurityMeasure.GINI) {
            score = 1.0;
            for (double count : this.counts) {
                double p = count / this.sumOfWeights;
                score -= p * p;
            }
        } else if (this.impurityMeasure == ImpurityMeasure.CLASSIFICATION_ERROR) {
            double maxClass = 0.0;
            for (double count : this.counts) {
                maxClass = Math.max(maxClass, count / this.sumOfWeights);
            }
            score = 1.0 - maxClass;
        }
        // The entropy sum above is <= 0; abs() turns it into the usual positive value
        return Math.abs(score);
    }

    /** @return the total weight of all points currently added */
    public double getSumOfWeights() {
        return this.sumOfWeights;
    }

    /** @return the measure this score computes */
    public ImpurityMeasure getImpurityMeasure() {
        return this.impurityMeasure;
    }

    /**
     * Returns each class's share of the total weight as a result vector.
     *
     * <p>NOTE(review): when no weight has been added ({@code sumOfWeights == 0})
     * every entry is computed as {@code x / 0.0}, yielding NaN or infinity;
     * callers may need to guard against that.
     *
     * @return per-class proportions {@code counts[i] / sumOfWeights}
     */
    public CategoricalResults getResults() {
        CategoricalResults cr = new CategoricalResults(this.counts.length);
        for (int i = 0; i < this.counts.length; ++i) {
            cr.setProb(i, this.counts[i] / this.sumOfWeights);
        }
        return cr;
    }

    /**
     * Equivalent to {@code gain(wholeData, 1.0, splits)}.
     *
     * @param wholeData the unsplit data
     * @param splits the partitions of {@code wholeData}; at least one is required
     * @return the gain of the split under the splits' impurity measure
     */
    public static double gain(ImpurityScore wholeData, ImpurityScore ... splits) {
        return ImpurityScore.gain(wholeData, 1.0, splits);
    }

    /**
     * Computes how much a split reduces impurity relative to the whole data.
     * The measure is taken from {@code splits[0]}.
     * <ul>
     *   <li>NMI: {@code 2*I(C;S) / (H(S) + H(C))} using natural logs, or 0.0 if
     *       both entropies are zero</li>
     *   <li>INFORMATION_GAIN_RATIO: {@code (score(whole) - sum p_s*score(s)) / (1 + H(S))}</li>
     *   <li>otherwise: {@code score(whole) - sum p_s*score(s)}</li>
     * </ul>
     * where {@code p_s = weight(s) / (wholeScale * weight(wholeData))}.
     *
     * @param wholeData the unsplit data
     * @param wholeScale multiplier applied to {@code wholeData}'s weights before
     *                   they are compared with the split weights
     * @param splits the partitions of {@code wholeData}; at least one is required
     * @return the gain of the split
     * @throws ArrayIndexOutOfBoundsException if {@code splits} is empty
     */
    public static double gain(ImpurityScore wholeData, double wholeScale, ImpurityScore ... splits) {
        double sumOfAllSums = wholeScale * wholeData.sumOfWeights;
        ImpurityMeasure measure = splits[0].impurityMeasure;
        if (measure == ImpurityMeasure.NMI) {
            return normalizedMutualInformation(wholeData, wholeScale, sumOfAllSums, splits);
        }
        double splitScore = 0.0;
        if (measure == ImpurityMeasure.INFORMATION_GAIN_RATIO) {
            // NOTE: starts at 1.0 (not 0.0), so the denominator is 1 + H(S) and stays
            // positive even when only one split is non-empty. H(S) here uses the
            // natural log, whereas getScore() measures entropy in bits.
            double splitInfo = 1.0;
            for (ImpurityScore split : splits) {
                double p = split.getSumOfWeights() / sumOfAllSums;
                if (p <= 0.0) {
                    continue;
                }
                splitScore += p * split.getScore();
                splitInfo -= p * Math.log(p);
            }
            return (wholeData.getScore() - splitScore) / splitInfo;
        }
        for (ImpurityScore split : splits) {
            double p = split.getSumOfWeights() / sumOfAllSums;
            if (p <= 0.0) {
                continue;
            }
            splitScore += p * split.getScore();
        }
        return wholeData.getScore() - splitScore;
    }

    /**
     * Normalized mutual information between the class label C and the split
     * assignment S: {@code 2*I(C;S) / (H(S) + H(C))}, all in natural logs.
     * Returns 0.0 when both entropies are zero (one class, one non-empty split),
     * where the ratio would otherwise be 0/0.
     */
    private static double normalizedMutualInformation(ImpurityScore wholeData, double wholeScale,
            double sumOfAllSums, ImpurityScore[] splits) {
        // H(S) depends only on the split weights, so compute it independently of
        // the class loop; every non-empty split contributes regardless of class.
        double splitEntropy = 0.0;
        for (ImpurityScore split : splits) {
            double p_s = split.sumOfWeights / sumOfAllSums;
            if (p_s > 0.0) {
                splitEntropy += p_s * Math.log(p_s);
            }
        }
        double mi = 0.0;
        double classEntropy = 0.0;
        for (int c = 0; c < wholeData.counts.length; ++c) {
            double p_c = wholeScale * wholeData.counts[c] / sumOfAllSums;
            if (p_c <= 0.0) {
                continue;
            }
            double logP_c = Math.log(p_c);
            classEntropy += p_c * logP_c;
            for (ImpurityScore split : splits) {
                double p_s = split.sumOfWeights / sumOfAllSums;
                if (p_s <= 0.0) {
                    continue;
                }
                double p_cs = split.counts[c] / sumOfAllSums;
                if (p_cs <= 0.0) {
                    continue;
                }
                // I(C;S) term: p(c,s) * log(p(c,s) / (p(c) * p(s)))
                mi += p_cs * (Math.log(p_cs) - logP_c - Math.log(p_s));
            }
        }
        double denominator = Math.abs(splitEntropy) + Math.abs(classEntropy);
        if (denominator <= 0.0) {
            return 0.0;
        }
        return 2.0 * mi / denominator;
    }

    /** Returns an independent copy; later changes to either do not affect the other. */
    @Override
    protected ImpurityScore clone() {
        return new ImpurityScore(this);
    }

    /** Impurity measures supported by {@link ImpurityScore}. */
    public static enum ImpurityMeasure {
        /** Entropy in bits; gain is the weighted reduction in entropy. */
        INFORMATION_GAIN,
        /** Entropy in bits; the gain is divided by 1 + the split's natural-log entropy. */
        INFORMATION_GAIN_RATIO,
        /** Gain is 2*I(C;S) / (H(S) + H(C)) between class and split assignment. */
        NMI,
        /** Gini impurity, 1 - sum p^2. */
        GINI,
        /** Misclassification rate, 1 - max p. */
        CLASSIFICATION_ERROR;

    }
}

