/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.exceptions.FailedToFitException;
import jsat.linear.Vec;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;

/**
 * Implementation of the AdaBoost.M1 boosting algorithm for classification.
 * <p>
 * Each boosting round does the following:
 * <ul>
 * <li>trains the weak learner on the current weighting of the training set;</li>
 * <li>measures the learner's weighted training error;</li>
 * <li>multiplies the weight of every correctly classified point by
 * {@code bt = error / (1 - error)};</li>
 * <li>stores a clone of the learner with a vote weight of {@code log(1 / bt)}.</li>
 * </ul>
 * Boosting stops early when a round's weighted error exceeds 0.5 or reaches 0.
 * <p>
 * The weak learner must support weighted data, because re-weighting the data set
 * is how each round is steered toward previously misclassified points.
 */
public class AdaBoostM1
implements Classifier,
Parameterized {
    private static final long serialVersionUID = 4205232097748332861L;
    /** Prototype weak learner; it is trained in place each round and a clone of it is stored. */
    private Classifier weakLearner;
    /** Upper bound on the number of boosting rounds. */
    private int maxIterations;
    /** Weak learners kept by the most recent call to train, in the order they were produced. */
    protected List<Classifier> hypoths;
    /** Vote weight of each entry of {@link #hypoths}, index-aligned with it. */
    protected List<Double> hypWeights;
    /** Target variable information captured at training time; {@code null} until trained. */
    protected CategoricalData predicting;

    /**
     * Creates a new AdaBoost.M1 learner.
     *
     * @param weakLearner the weak learner to boost; must support weighted data
     * @param maxIterations the maximum number of boosting rounds; must be at least 1
     * @throws FailedToFitException if {@code weakLearner} does not support weighted data
     * @throws IllegalArgumentException if {@code maxIterations} is less than 1
     */
    public AdaBoostM1(Classifier weakLearner, int maxIterations) {
        this.setWeakLearner(weakLearner);
        // Route through the setter so the constructor enforces the same bound as the setter.
        this.setMaxIterations(maxIterations);
    }

    /**
     * Copy constructor. The weak learner, any trained ensemble members and the target
     * information are cloned, so the copy shares no mutable model state with {@code toCopy}.
     *
     * @param toCopy the instance to copy
     */
    public AdaBoostM1(AdaBoostM1 toCopy) {
        this(toCopy.weakLearner.clone(), toCopy.maxIterations);
        if (toCopy.hypWeights != null) {
            this.hypWeights = new DoubleList(toCopy.hypWeights);
        }
        if (toCopy.hypoths != null) {
            this.hypoths = new ArrayList<Classifier>(toCopy.hypoths.size());
            for (Classifier hypoth : toCopy.hypoths) {
                this.hypoths.add(hypoth.clone());
            }
        }
        if (toCopy.predicting != null) {
            this.predicting = toCopy.predicting.clone();
        }
    }

    /**
     * Returns the maximum number of boosting rounds.
     *
     * @return the maximum number of boosting rounds
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Returns a read-only view of the weak learners in the trained ensemble.
     *
     * @return the ensemble members, or an empty list if the model has not been trained
     */
    public List<Classifier> getModels() {
        if (this.hypoths == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(this.hypoths);
    }

    /**
     * Returns a read-only view of the vote weights of the ensemble members.
     * Entries are index-aligned with {@link #getModels()}.
     *
     * @return the vote weights, or an empty list if the model has not been trained
     */
    public List<Double> getModelWeights() {
        if (this.hypWeights == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(this.hypWeights);
    }

    /**
     * Sets the maximum number of boosting rounds.
     *
     * @param maxIterations the maximum number of rounds; must be at least 1
     * @throws IllegalArgumentException if {@code maxIterations} is less than 1
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 1) {
            throw new IllegalArgumentException("Number of iterations must be a positive value, not " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /**
     * Returns the prototype weak learner that is boosted.
     *
     * @return the weak learner
     */
    public Classifier getWeakLearner() {
        return this.weakLearner;
    }

    /**
     * Sets the weak learner to boost.
     *
     * @param weakLearner the weak learner; must support weighted data
     * @throws FailedToFitException if {@code weakLearner} does not support weighted data
     */
    public void setWeakLearner(Classifier weakLearner) {
        if (!weakLearner.supportsWeightedData()) {
            throw new FailedToFitException("WeakLearner must support weighted data to be boosted");
        }
        this.weakLearner = weakLearner;
    }

    /**
     * Classifies a data point by weighted vote. Each ensemble member adds its vote
     * weight to the class it predicts as most likely, and the totals are then normalized.
     * <p>
     * NOTE(review): if training stopped in the first round because that round's error
     * exceeded 0.5, the ensemble is empty and no votes are cast.
     *
     * @param data the point to classify
     * @return the normalized weighted votes per class
     * @throws RuntimeException if the classifier has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.predicting == null) {
            throw new RuntimeException("Classifier has not been trained yet");
        }
        CategoricalResults cr = new CategoricalResults(this.predicting.getNumOfCategories());
        for (int i = 0; i < this.hypoths.size(); ++i) {
            cr.incProb(this.hypoths.get(i).classify(data).mostLikely(), this.hypWeights.get(i));
        }
        cr.normalize();
        return cr;
    }

    /**
     * Trains the boosted ensemble, replacing any previously trained model.
     * <p>
     * Boosting overwrites the weights of the points in {@code dataSet} while it runs.
     * The original weights are always restored before this method returns, including
     * on early stopping or if the weak learner throws.
     *
     * @param dataSet the training data
     * @param parallel passed through to the weak learner's {@code train} call
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.predicting = dataSet.getPredicting();
        this.hypWeights = new DoubleList(this.maxIterations);
        this.hypoths = new ArrayList<Classifier>(this.maxIterations);
        Vec origWeights = dataSet.getDataWeights();
        try {
            this.boost(dataSet, parallel);
        } finally {
            // Previously skipped on early stops, which left the caller's data re-weighted.
            for (int i = 0; i < dataSet.size(); ++i) {
                dataSet.setWeight(i, origWeights.get(i));
            }
        }
    }

    /**
     * Runs the boosting rounds. Each accepted weak learner and its vote weight are
     * appended to {@link #hypoths} and {@link #hypWeights}. The weights of
     * {@code dataSet} are overwritten; the caller must restore them.
     *
     * @param dataSet the training data whose weights are re-assigned
     * @param parallel passed through to the weak learner's {@code train} call
     */
    private void boost(ClassificationDataSet dataSet, boolean parallel) {
        int n = dataSet.size();
        // The stored weights sum to 'scaledBy' rather than to 1, so dividing by
        // scaledBy recovers the true distribution. Start from a uniform weighting.
        for (int i = 0; i < n; ++i) {
            dataSet.setWeight(i, 1.0);
        }
        double scaledBy = n;
        // Remember each point's outcome so it need not be re-classified when re-weighting.
        boolean[] wasCorrect = new boolean[n];
        for (int t = 0; t < this.maxIterations; ++t) {
            this.weakLearner.train(dataSet, parallel);
            double error = 0.0;
            for (int i = 0; i < n; ++i) {
                wasCorrect[i] = this.weakLearner.classify(dataSet.getDataPoint(i)).mostLikely()
                        == dataSet.getDataPointCategory(i);
                if (!wasCorrect[i]) {
                    error += dataSet.getWeight(i);
                }
            }
            error /= scaledBy;
            if (error == 0.0) {
                // A perfect learner cannot be re-weighted: bt would be 0 and its vote weight
                // infinite. If no learner has been kept yet, keep this one as the sole voter
                // so a perfect first round does not leave the model empty.
                if (this.hypoths.isEmpty()) {
                    this.hypoths.add(this.weakLearner.clone());
                    this.hypWeights.add(1.0);
                }
                return;
            }
            if (error > 0.5) {
                return;
            }
            // bt <= 1 here, so correctly classified points lose weight relative to the rest.
            double bt = error / (1.0 - error);
            double Zt = 0.0;
            double newScale = scaledBy;
            for (int i = 0; i < n; ++i) {
                if (wasCorrect[i]) {
                    dataSet.setWeight(i, dataSet.getWeight(i) * bt);
                }
                double trueWeight = dataSet.getWeight(i) / scaledBy;
                // Grow the scale so the smallest weight is not below about 1 once rescaled.
                if (1.0 / trueWeight > newScale) {
                    newScale = 1.0 / trueWeight;
                }
                Zt += trueWeight;
            }
            // Normalize by Zt and rescale so the stored weights sum to newScale.
            for (int i = 0; i < n; ++i) {
                dataSet.setWeight(i, dataSet.getWeight(i) / scaledBy * newScale / Zt);
            }
            scaledBy = newScale;
            this.hypoths.add(this.weakLearner.clone());
            this.hypWeights.add(Math.log(1.0 / bt));
        }
    }

    /**
     * Returns {@code false}: training resets every data point weight to 1.0 before
     * boosting, so weights supplied on the training set are not taken into account.
     *
     * @return {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns a deep copy of this learner, including any trained ensemble.
     *
     * @return a copy of this object
     */
    @Override
    public AdaBoostM1 clone() {
        return new AdaBoostM1(this);
    }
}

