/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.trees.DecisionTree;
import jsat.distributions.Distribution;
import jsat.distributions.Uniform;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;

/**
 * Boosted ensemble of weak classifiers for a two-class problem.
 *
 * <p>Each round trains a fresh clone of the weak learner on a re-weighted copy
 * of the training set. The weight given to a training point after a round is
 * {@code exp(lambda * (f - y)^2 - (1 - lambda) * f^2)}, where {@code f} is the
 * ensemble's current score for the point and {@code y} is its label mapped to
 * {@code -1}/{@code +1}. The final decision is the sign of the weighted sum of
 * the weak learners' outputs.
 *
 * <p>Only category indices {@code 0} and {@code 1} are handled: category
 * {@code 0} is the negative class and category {@code 1} the positive class.
 */
public class EmphasisBoost
implements Classifier,
Parameterized,
BinaryScoreClassifier {
    private static final long serialVersionUID = -6372897830449685891L;

    /**
     * Largest weighted edge fed into the hypothesis-weight formula. An edge of
     * exactly 1 would make {@code log((1 + e) / (1 - e))} infinite, so the edge
     * is capped slightly below 1 to keep every hypothesis weight finite.
     */
    private static final double MAX_EDGE = 1.0 - 1e-12;

    /** Prototype weak learner; a clone of it is trained in every boosting round. */
    @Parameter.ParameterHolder
    private Classifier weakLearner;
    /** Upper bound on the number of boosting rounds. */
    private int maxIterations;
    /** Trained weak learners, in the order they were added; {@code null} until trained. */
    protected List<Classifier> hypoths;
    /** Weight of each weak learner, parallel to {@link #hypoths}; {@code null} until trained. */
    protected List<Double> hypWeights;
    /** Description of the target variable; {@code null} until trained. */
    protected CategoricalData predicting;
    /** Trade-off parameter of the re-weighting function used in {@link #train}. */
    private double lambda;

    /**
     * Creates a booster using a {@code DecisionTree(6)} as weak learner, at
     * most 200 rounds and {@code lambda = 0.35}.
     */
    public EmphasisBoost() {
        this(new DecisionTree(6), 200, 0.35);
    }

    /**
     * Creates a booster with the given configuration.
     *
     * @param weakLearner the weak learner to boost; must support weighted data
     * @param maxIterations the maximum number of boosting rounds; must be at least 1
     * @param lambda the trade-off parameter of the re-weighting function
     * @throws IllegalArgumentException if the weak learner does not support
     *         weighted data or {@code maxIterations < 1}
     */
    public EmphasisBoost(Classifier weakLearner, int maxIterations, double lambda) {
        this.setWeakLearner(weakLearner);
        this.setMaxIterations(maxIterations);
        this.setLambda(lambda);
    }

    /**
     * Copy constructor. Clones the weak learner prototype and, when the source
     * has been trained, every learned hypothesis, its weight and the target
     * description.
     *
     * @param toClone the instance to copy
     */
    protected EmphasisBoost(EmphasisBoost toClone) {
        this(toClone.weakLearner.clone(), toClone.maxIterations, toClone.lambda);
        if (toClone.hypWeights != null) {
            this.hypWeights = new DoubleList(toClone.hypWeights);
            this.hypoths = new ArrayList<Classifier>(toClone.maxIterations);
            for (Classifier weak : toClone.hypoths) {
                this.hypoths.add(weak.clone());
            }
            this.predicting = toClone.predicting.clone();
        }
    }

    /**
     * Returns a read-only view of the trained weak learners.
     *
     * @return the weak learners in training order, or an empty list if the
     *         model has not been trained
     */
    public List<Classifier> getModels() {
        if (this.hypoths == null) {
            return Collections.<Classifier>emptyList();
        }
        return Collections.unmodifiableList(this.hypoths);
    }

    /**
     * Returns a read-only view of the weight assigned to each weak learner.
     *
     * @return the weights, parallel to {@link #getModels()}, or an empty list
     *         if the model has not been trained
     */
    public List<Double> getModelWeights() {
        if (this.hypWeights == null) {
            return Collections.<Double>emptyList();
        }
        return Collections.unmodifiableList(this.hypWeights);
    }

    /**
     * @return the maximum number of boosting rounds
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the maximum number of boosting rounds.
     *
     * @param maxIterations the new limit; must be at least 1
     * @throws IllegalArgumentException if {@code maxIterations < 1}
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 1) {
            throw new IllegalArgumentException("Iterations must be positive, not " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /**
     * @return the prototype weak learner that is cloned for each round
     */
    public Classifier getWeakLearner() {
        return this.weakLearner;
    }

    /**
     * Sets the prototype weak learner.
     *
     * @param weakLearner the weak learner to boost
     * @throws IllegalArgumentException if the learner does not support weighted data
     */
    public void setWeakLearner(Classifier weakLearner) {
        if (!weakLearner.supportsWeightedData()) {
            throw new IllegalArgumentException("WeakLearner must support weighted data to be boosted");
        }
        this.weakLearner = weakLearner;
    }

    /**
     * Suggests a search distribution for {@code lambda}. The data set is not
     * inspected; the result is always uniform over {@code [0.25, 0.45]}.
     *
     * @param d the data set (unused)
     * @return a uniform distribution over {@code [0.25, 0.45]}
     */
    public static Distribution guessLambda(DataSet d) {
        return new Uniform(0.25, 0.45);
    }

    /**
     * Sets the trade-off parameter of the re-weighting function. No range check
     * is applied here.
     * NOTE(review): the formula in {@link #train} suggests values in [0, 1] are
     * intended - confirm before adding validation.
     *
     * @param lambda the new value
     */
    public void setLambda(double lambda) {
        this.lambda = lambda;
    }

    /**
     * @return the trade-off parameter of the re-weighting function
     */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Computes the ensemble score: the sum over all weak learners of the
     * learner's output in {@code [-1, 1]} times its weight. Negative scores
     * map to category 0 in {@link #classify}.
     *
     * @param dp the data point to score
     * @return the weighted ensemble score
     * @throws RuntimeException if the model has not been trained
     */
    @Override
    public double getScore(DataPoint dp) {
        if (this.hypoths == null || this.hypWeights == null) {
            // Mirror classify() instead of failing with a bare NullPointerException.
            throw new RuntimeException("Classifier has not been trained yet");
        }
        double score = 0.0;
        for (int i = 0; i < this.hypoths.size(); ++i) {
            score += this.H(this.hypoths.get(i), dp) * this.hypWeights.get(i);
        }
        return score;
    }

    /**
     * Classifies a data point by the sign of {@link #getScore}. The result is a
     * hard decision: probability 1 for category 0 when the score is negative,
     * otherwise probability 1 for category 1.
     *
     * @param data the data point to classify
     * @return the hard classification result
     * @throws RuntimeException if the model has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.predicting == null) {
            throw new RuntimeException("Classifier has not been trained yet");
        }
        CategoricalResults cr = new CategoricalResults(this.predicting.getNumOfCategories());
        double score = this.getScore(data);
        if (score < 0.0) {
            cr.setProb(0, 1.0);
        } else {
            cr.setProb(1, 1.0);
        }
        return cr;
    }

    /**
     * Maps a weak learner's probability for category 1 from {@code [0, 1]}
     * onto {@code [-1, 1]}.
     */
    private double H(Classifier weak, DataPoint dp) {
        CategoricalResults catResult = weak.classify(dp);
        return catResult.getProb(1) * 2.0 - 1.0;
    }

    /**
     * Trains the ensemble. Works on a shallow clone of {@code dataSet} whose
     * weights are all reset to {@code 1/N}, so weights on the input are not
     * used. Boosting stops early, keeping the hypotheses collected so far, as
     * soon as a weak learner's weighted edge is negative or NaN.
     *
     * @param dataSet the training data
     * @param parallel forwarded unchanged to each weak learner's {@code train}
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.predicting = dataSet.getPredicting();
        this.hypWeights = new DoubleList(this.maxIterations);
        this.hypoths = new ArrayList<Classifier>(this.maxIterations);
        int N = dataSet.size();
        ClassificationDataSet cds = dataSet.shallowClone();
        // Start from a uniform weighting that sums to one.
        for (int i = 0; i < cds.size(); ++i) {
            cds.setWeight(i, 1.0 / (double)N);
        }
        // Output in [-1, 1] of the current round's weak learner per point.
        double[] H_cur = new double[N];
        // Running ensemble score per point, accumulated over all rounds.
        double[] curH_Result = new double[N];
        for (int t = 0; t < this.maxIterations; ++t) {
            Classifier weak = this.weakLearner.clone();
            weak.train(cds, parallel);
            // Weighted correlation between prediction and label (the "edge"),
            // in [-1, 1] because the weights sum to one.
            double error = 0.0;
            for (int i = 0; i < cds.size(); ++i) {
                DataPoint dp = cds.getDataPoint(i);
                double y_hat = H_cur[i] = this.H(weak, dp);
                double y_true = cds.getDataPointCategory(i) * 2 - 1;
                error += cds.getWeight(i) * y_hat * y_true;
            }
            // A negative edge means the learner is worse than chance; a NaN edge
            // means it produced unusable outputs. Either way, stop without
            // adding it rather than corrupting the ensemble.
            if (error < 0.0 || Double.isNaN(error)) {
                return;
            }
            // A perfectly confident learner gives an edge of 1 (or marginally
            // above through rounding), which would make alpha infinite or NaN
            // and turn every subsequent weight into NaN. Cap it just below 1.
            double edge = Math.min(error, MAX_EDGE);
            double alpha_m = Math.log((1.0 + edge) / (1.0 - edge)) / 2.0;
            double weightSum = 0.0;
            for (int i = 0; i < cds.size(); ++i) {
                curH_Result[i] += alpha_m * H_cur[i];
                double f_t = curH_Result[i];
                double y_true = cds.getDataPointCategory(i) * 2 - 1;
                double w_i = Math.exp(this.lambda * Math.pow(f_t - y_true, 2.0) - (1.0 - this.lambda) * f_t * f_t);
                // Overflow guard: substitute a fixed finite weight.
                if (Double.isInfinite(w_i)) {
                    w_i = 50.0;
                }
                weightSum += w_i;
                cds.setWeight(i, w_i);
            }
            // Re-normalise so the weights sum to one for the next round.
            for (int i = 0; i < cds.size(); ++i) {
                cds.setWeight(i, cds.getWeight(i) / weightSum);
            }
            this.hypoths.add(weak);
            this.hypWeights.add(alpha_m);
        }
    }

    /**
     * @return {@code false}; {@link #train} overwrites every data point weight
     *         before the first round
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * @return a copy of this classifier, including any trained hypotheses
     */
    @Override
    public EmphasisBoost clone() {
        return new EmphasisBoost(this);
    }
}

