/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;

/**
 * SAMME (Stagewise Additive Modeling using a Multi-class Exponential loss): a multi-class
 * generalization of AdaBoost built on top of a weighted weak learner.
 * <p>
 * Each boosting round does the following:
 * <ul>
 * <li>It trains the weak learner on the current point weights and measures its weighted error
 * {@code err}.</li>
 * <li>It gives the learner the vote weight {@code log2((1 - err) / err) + log2(K - 1)}, where
 * {@code K} is the number of classes.</li>
 * <li>It multiplies the weight of every misclassified point by {@code (1 - err)(K - 1) / err}.</li>
 * </ul>
 * Training stops early when a learner is no better than random guessing ({@code err >= 1 - 1/K})
 * or classifies the (weighted) training data perfectly.
 */
public class SAMME implements Classifier, Parameterized {
    private static final long serialVersionUID = -3584203799253810599L;
    /** Prototype weak learner; it is retrained every round and a clone of it is kept. */
    private Classifier weakLearner;
    /** Upper bound on the number of boosting rounds. */
    private int maxIterations;
    /** Trained weak models, in the order they were produced; {@code null} before training. */
    private List<Classifier> hypoths;
    /** Vote weight of each entry of {@link #hypoths}; {@code null} before training. */
    private List<Double> hypWeights;
    /** Target attribute of the training data; {@code null} until {@link #train} is called. */
    private CategoricalData predicting;

    /**
     * Creates a new, untrained SAMME ensemble.
     *
     * @param weakLearner   the weak learner to boost; it must support weighted data
     * @param maxIterations the maximum number of boosting rounds
     * @throws IllegalArgumentException if {@code weakLearner} does not support weighted data
     */
    public SAMME(Classifier weakLearner, int maxIterations) {
        if (!weakLearner.supportsWeightedData()) {
            throw new IllegalArgumentException("WeakLearner must support weighted data to be boosted");
        }
        this.weakLearner = weakLearner;
        this.maxIterations = maxIterations;
    }

    /**
     * Classifies a point by a weighted vote of the trained weak models.
     * <p>
     * Each model votes for its single most likely class with its model weight. The votes are
     * then normalised into a probability vector.
     *
     * @param data the point to classify
     * @return the normalised vote distribution over the classes
     * @throws IllegalStateException if the classifier has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.predicting == null) {
            throw new IllegalStateException("Classifier has not been trained yet");
        }
        CategoricalResults cr = new CategoricalResults(this.predicting.getNumOfCategories());
        for (int i = 0; i < this.hypoths.size(); ++i) {
            cr.incProb(this.hypoths.get(i).classify(data).mostLikely(), this.hypWeights.get(i));
        }
        // NOTE(review): if the very first weak learner was no better than random guessing, the
        // ensemble is empty and all votes are zero here -- confirm how normalize() handles that.
        cr.normalize();
        return cr;
    }

    /**
     * Returns a read-only view of the trained weak models.
     *
     * @return the models, or an empty list if the ensemble has not been trained
     */
    public List<Classifier> getModels() {
        if (this.hypoths == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(this.hypoths);
    }

    /**
     * Returns a read-only view of the vote weights, aligned index-for-index with
     * {@link #getModels()}.
     *
     * @return the model weights, or an empty list if the ensemble has not been trained
     */
    public List<Double> getModelWeights() {
        if (this.hypWeights == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(this.hypWeights);
    }

    /**
     * Trains the boosted ensemble, replacing any previously trained one.
     * <p>
     * The caller's per-point weights are ignored. Boosting starts from a weight of 1.0 on every
     * point of a shallow clone of {@code dataSet}. After each round the weights are rescaled to
     * sum to the number of points. This prevents them from overflowing and leaves their
     * relative sizes (and so the weighted error) unchanged.
     *
     * @param dataSet  the labelled training data
     * @param parallel forwarded to the weak learner's {@code train} call
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.predicting = dataSet.getPredicting();
        this.hypWeights = new DoubleList(this.maxIterations);
        this.hypoths = new ArrayList<Classifier>();
        final int K = this.predicting.getNumOfCategories();
        // SAMME's multi-class correction term, in the same base-2 units as the model weights.
        final double logK = Math.log((double) K - 1.0) / Math.log(2.0);
        // NOTE(review): assumes setWeight on the shallow clone does not alter the caller's
        // weights -- confirm against ClassificationDataSet.shallowClone().
        ClassificationDataSet cds = dataSet.shallowClone();
        final int n = cds.size();
        for (int i = 0; i < n; ++i) {
            cds.setWeight(i, 1.0);
        }
        boolean[] wasCorrect = new boolean[n];
        for (int t = 0; t < this.maxIterations; ++t) {
            this.weakLearner.train(cds, parallel);
            final double error = weightedError(cds, wasCorrect);
            if (error >= 1.0 - 1.0 / (double) K) {
                return; // no better than random guessing: stop boosting
            }
            final double am = Math.log((1.0 - error) / error) / Math.log(2.0) + logK;
            if (error == 0.0 || Double.isInfinite(am)) {
                // (Numerically) perfect learner: its weight would be infinite, so stop here.
                // Keep it if it is the only model, otherwise the ensemble would be empty.
                if (this.hypoths.isEmpty()) {
                    this.hypoths.add(this.weakLearner.clone());
                    this.hypWeights.add(1.0);
                }
                return;
            }
            this.hypoths.add(this.weakLearner.clone());
            this.hypWeights.add(am);
            // SAMME scales misclassified points by exp(alpha) with a natural-log alpha, which is
            // (1 - err)(K - 1) / err. Since am is base-2, Math.exp(am) would over-boost them.
            reweight(cds, wasCorrect, (1.0 - error) / error * ((double) K - 1.0));
        }
    }

    /**
     * Records, for each point of {@code cds}, whether the current weak learner classifies it
     * correctly, and returns the weighted fraction of points it gets wrong.
     */
    private double weightedError(ClassificationDataSet cds, boolean[] wasCorrect) {
        double total = 0.0;
        double wrong = 0.0;
        for (int i = 0; i < wasCorrect.length; ++i) {
            final double w = cds.getWeight(i);
            total += w;
            wasCorrect[i] = this.weakLearner.classify(cds.getDataPoint(i)).mostLikely()
                    == cds.getDataPointCategory(i);
            if (!wasCorrect[i]) {
                wrong += w;
            }
        }
        return wrong / total;
    }

    /**
     * Multiplies the weight of every misclassified point by {@code boost}, then rescales all
     * weights so they again sum to the number of points.
     */
    private static void reweight(ClassificationDataSet cds, boolean[] wasCorrect, double boost) {
        double newTotal = 0.0;
        for (int i = 0; i < wasCorrect.length; ++i) {
            double w = cds.getWeight(i);
            if (!wasCorrect[i]) {
                w *= boost;
                cds.setWeight(i, w);
            }
            newTotal += w;
        }
        final double scale = wasCorrect.length / newTotal;
        for (int i = 0; i < wasCorrect.length; ++i) {
            cds.setWeight(i, cds.getWeight(i) * scale);
        }
    }

    /**
     * Reports that this ensemble does not use caller-supplied point weights.
     * {@link #train} overwrites every weight with 1.0 before boosting.
     *
     * @return always {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns a deep copy of this ensemble: the weak-learner prototype, every trained model and
     * the target attribute are cloned, and the model weights are copied.
     *
     * @return an independent copy of this classifier
     */
    @Override
    public SAMME clone() {
        SAMME clone = new SAMME(this.weakLearner.clone(), this.maxIterations);
        if (this.hypWeights != null) {
            clone.hypWeights = new DoubleList(this.hypWeights);
        }
        if (this.hypoths != null) {
            clone.hypoths = new ArrayList<Classifier>(this.hypoths.size());
            for (int i = 0; i < this.hypoths.size(); ++i) {
                clone.hypoths.add(this.hypoths.get(i).clone());
            }
        }
        if (this.predicting != null) {
            clone.predicting = this.predicting.clone();
        }
        return clone;
    }
}

