/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.parameters.Parameterized;
import jsat.regression.MultipleLinearRegression;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;

/**
 * Binary classifier implementing LogitBoost. It builds the additive model
 * {@code F(x) = fScaleConstant * sum_m f_m(x)} from weighted regressors {@code f_m} and
 * predicts {@code P(y = 1 | x) = e^F / (e^F + e^-F)}.
 * <p>
 * Each boosting round fits a fresh clone of the base regressor to the working responses
 * {@code z_i = 1 / p_i} for class 1 or {@code z_i = -1 / (1 - p_i)} for class 0. The
 * responses are clamped to {@code [-zMax, zMax]} and weighted by
 * {@code w_i = max(p_i (1 - p_i), 2e-15)}, where {@code p_i} is the current model's
 * probability for training point {@code i}.
 */
public class LogitBoost
implements Classifier,
Parameterized {
    private static final long serialVersionUID = 1621062168467402062L;
    /** Multiplier applied to the summed base-learner outputs to form F(x). */
    protected double fScaleConstant = 0.5;
    /** Regressors learned by {@link #train}, in boosting order; {@code null} until trained. */
    protected List<Regressor> baseLearners;
    /** Prototype regressor; a fresh clone of it is trained in every boosting round. */
    protected Regressor baseLearner;
    private int maxIterations;
    private double zMax = 3.0;

    /**
     * Creates a LogitBoost classifier that boosts {@link MultipleLinearRegression} models.
     *
     * @param M the number of boosting iterations
     */
    public LogitBoost(int M) {
        this(new MultipleLinearRegression(true), M);
    }

    /**
     * Creates a LogitBoost classifier around the given base regressor.
     *
     * @param baseLearner prototype regressor, cloned once per iteration; it must support
     *                    weighted data points
     * @param M           the number of boosting iterations
     * @throws IllegalArgumentException if {@code baseLearner} does not support weighted data
     */
    public LogitBoost(Regressor baseLearner, int M) {
        if (!baseLearner.supportsWeightedData()) {
            throw new IllegalArgumentException("Base Learner must support weighted data points to be boosted");
        }
        this.baseLearner = baseLearner;
        this.maxIterations = M;
    }

    /**
     * Copy constructor: deep-copies every field of {@code toCopy}, including the prototype
     * learner and any trained learners.
     *
     * @param toCopy the instance to copy
     */
    protected LogitBoost(LogitBoost toCopy) {
        this.fScaleConstant = toCopy.fScaleConstant;
        this.maxIterations = toCopy.maxIterations;
        this.zMax = toCopy.zMax;
        if (toCopy.baseLearner != null) {
            this.baseLearner = toCopy.baseLearner.clone();
        }
        if (toCopy.baseLearners != null) {
            this.baseLearners = new ArrayList<Regressor>(toCopy.baseLearners.size());
            for (Regressor r : toCopy.baseLearners) {
                this.baseLearners.add(r.clone());
            }
        }
    }

    /**
     * Returns the trained regressors in boosting order.
     *
     * @return an unmodifiable view of the learned regressors, or an empty list if the model
     *         has not been trained
     */
    public List<Regressor> getModels() {
        if (this.baseLearners == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(this.baseLearners);
    }

    /**
     * Sets the number of boosting iterations used by the next call to {@link #train}.
     *
     * @param maxIterations the number of iterations
     */
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /**
     * Returns the number of boosting iterations.
     *
     * @return the number of boosting iterations
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the bound applied to the magnitude of each working response during training.
     *
     * @param zMax a finite, strictly positive bound
     * @throws ArithmeticException if {@code zMax} is NaN, infinite, or not positive
     */
    public void setzMax(double zMax) {
        if (Double.isInfinite(zMax) || Double.isNaN(zMax) || zMax <= 0.0) {
            throw new ArithmeticException("Invalid penalty given: " + zMax);
        }
        this.zMax = zMax;
    }

    /**
     * Returns the bound applied to the magnitude of each working response.
     *
     * @return the current bound
     */
    public double getzMax() {
        return this.zMax;
    }

    /**
     * Classifies a point into class 0 or 1.
     *
     * @param data the point to classify
     * @return a two-category result holding {@code P(y = 1)} at index 1 and its complement
     *         at index 0
     * @throws UntrainedModelException if {@link #train} has not been called
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        // baseLearner is set by every constructor; only baseLearners shows whether training ran
        if (this.baseLearners == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        double p = this.P(data);
        CategoricalResults cr = new CategoricalResults(2);
        cr.setProb(1, p);
        cr.setProb(0, 1.0 - p);
        return cr;
    }

    /**
     * Runs {@link #getMaxIterations()} rounds of LogitBoost on a binary data set, replacing
     * any previously trained model.
     *
     * @param dataSet  the training data; it must have exactly two classes
     * @param parallel ignored by this implementation
     * @throws FailedToFitException if the data set does not have exactly two classes
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("LogitBoost only supports binary decision tasks, not " + dataSet.getClassSize() + " class problems");
        }
        RegressionDataSet rds = new RegressionDataSet(dataSet.getAsFloatDPPList());
        this.baseLearners = new ArrayList<Regressor>(this.maxIterations);
        int N = dataSet.size();
        // Unscaled running sum of learner outputs per training point. Adding learners in the
        // same order as F() does yields identical values without re-evaluating old learners.
        double[] rawF = new double[N];
        for (int m = 0; m < this.maxIterations; ++m) {
            for (int i = 0; i < N; ++i) {
                double pi = logistic(rawF[i] * this.fScaleConstant);
                double zi = dataSet.getDataPointCategory(i) == 1 ? Math.min(this.zMax, 1.0 / pi) : Math.max(-this.zMax, -1.0 / (1.0 - pi));
                // Floor keeps weights strictly positive once p saturates at 0 or 1
                double wi = Math.max(pi * (1.0 - pi), 2.0E-15);
                rds.setWeight(i, wi);
                rds.setTargetValue(i, zi);
            }
            Regressor f = this.baseLearner.clone();
            f.train(rds);
            this.baseLearners.add(f);
            // The cached sums are only read by a following round
            if (m + 1 < this.maxIterations) {
                for (int i = 0; i < N; ++i) {
                    rawF[i] += f.regress(rds.getDataPoint(i));
                }
            }
        }
    }

    /**
     * Evaluates the additive model.
     *
     * @param x the point to evaluate
     * @return {@code fScaleConstant} times the sum of all trained learners' outputs at {@code x}
     */
    private double F(DataPoint x) {
        double fx = 0.0;
        for (Regressor fm : this.baseLearners) {
            fx += fm.regress(x);
        }
        return fx * this.fScaleConstant;
    }

    /**
     * Computes the model's probability that {@code x} belongs to class 1.
     *
     * @param x the point to evaluate
     * @return {@code e^F / (e^F + e^-F)} for {@code F = F(x)}
     */
    protected double P(DataPoint x) {
        return logistic(this.F(x));
    }

    /**
     * Maps a model score to a probability.
     *
     * @param fx the scaled model score F(x)
     * @return {@code e^fx / (e^fx + e^-fx)}, computed as {@code 1 / (1 + e^(-2 fx))}
     */
    private static double logistic(double fx) {
        // This form saturates cleanly to 0 or 1 for large |fx| instead of forming inf/inf = NaN
        return 1.0 / (1.0 + Math.exp(-2.0 * fx));
    }

    /**
     * Reports that this classifier does not support weighted data points.
     *
     * @return {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns a deep copy of this classifier, including any trained learners.
     *
     * @return a copy of this classifier
     */
    @Override
    public LogitBoost clone() {
        return new LogitBoost(this);
    }
}

