/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import jsat.DataSet;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.UpdateableClassifier;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.parameters.Parameterized;
import jsat.regression.BaseUpdateableRegressor;
import jsat.regression.RegressionDataSet;
import jsat.regression.UpdateableRegressor;

/**
 * Online Passive-Aggressive (PA) learner for binary classification and for regression.
 * <p>
 * The model is a single weight vector {@code w} with no bias term. On each update the
 * model stays unchanged ("passive") when the example incurs zero loss. Otherwise
 * {@code w} is moved along the example's numeric feature vector {@code x}
 * ("aggressive") by a step size {@code tau}, chosen according to the configured
 * {@link Mode}:
 * <ul>
 * <li>{@link Mode#PA}: {@code tau = loss / ||x||^2}</li>
 * <li>{@link Mode#PA1}: {@code tau = min(C, loss / ||x||^2)}</li>
 * <li>{@link Mode#PA2}: {@code tau = loss / (||x||^2 + 1 / (2C))}</li>
 * </ul>
 * Classification uses the hinge loss {@code max(0, 1 - y * w.x)}, with class indices
 * 0/1 mapped to labels -1/+1. Regression uses the epsilon-insensitive loss
 * {@code max(0, |w.x - y| - eps)}. Only numeric features are used, and per-example
 * weights are ignored.
 * <p>
 * Reference: Crammer et al., "Online Passive-Aggressive Algorithms", JMLR 7 (2006).
 */
public class PassiveAggressive
implements UpdateableClassifier,
BinaryScoreClassifier,
UpdateableRegressor,
Parameterized,
SingleWeightVectorModel {
    private static final long serialVersionUID = -7130964391528405832L;

    /** Number of epochs used by the batch {@code train} methods; always at least 1. */
    private int epochs;
    /** Aggressiveness parameter used by {@link Mode#PA1} and {@link Mode#PA2}; always positive. */
    private double C = 0.01;
    /** Width of the insensitive zone for the regression loss; always non-negative. */
    private double eps = 0.001;
    /** Weight vector; {@code null} until one of the {@code setUp} methods is called. */
    private Vec w;
    /** Step-size rule applied on every non-zero-loss update. */
    private Mode mode;

    /**
     * Creates a {@link Mode#PA1} learner that uses 10 epochs when batch trained.
     */
    public PassiveAggressive() {
        this(10, Mode.PA1);
    }

    /**
     * Creates a new learner.
     *
     * @param epochs number of epochs used by the batch {@code train} methods; must be at least 1
     * @param mode   step-size rule to apply on each update
     * @throws IllegalArgumentException if {@code epochs < 1}
     */
    public PassiveAggressive(int epochs, Mode mode) {
        // Same contract as setEpochs, checked inline to avoid calling an overridable
        // method from the constructor.
        if (epochs < 1) {
            throw new IllegalArgumentException("epochs must be a positive value");
        }
        this.epochs = epochs;
        this.mode = mode;
    }

    /**
     * Sets the aggressiveness parameter used by {@link Mode#PA1} and {@link Mode#PA2}.
     *
     * @param C a finite, strictly positive value
     * @throws ArithmeticException if {@code C} is NaN, infinite, or not positive
     */
    public void setC(double C) {
        if (Double.isNaN(C) || Double.isInfinite(C) || C <= 0.0) {
            throw new ArithmeticException("Aggressiveness must be a positive constant");
        }
        this.C = C;
    }

    /**
     * @return the aggressiveness parameter
     */
    public double getC() {
        return this.C;
    }

    /**
     * Sets the step-size rule. Any value other than {@link Mode#PA1} or {@link Mode#PA2}
     * is treated as the unbounded {@link Mode#PA} rule.
     *
     * @param mode the step-size rule
     */
    public void setMode(Mode mode) {
        this.mode = mode;
    }

    /**
     * @return the step-size rule
     */
    public Mode getMode() {
        return this.mode;
    }

    /**
     * Sets the width of the insensitive zone used by the regression loss. Prediction
     * errors whose magnitude is at most {@code eps} cause no update.
     *
     * @param eps a finite, non-negative value
     * @throws IllegalArgumentException if {@code eps} is NaN, infinite, or negative
     */
    public void setEps(double eps) {
        // A NaN eps would make the loss NaN and poison every weight on the next update.
        if (Double.isNaN(eps) || Double.isInfinite(eps) || eps < 0.0) {
            throw new IllegalArgumentException("eps must be a non-negative finite value");
        }
        this.eps = eps;
    }

    /**
     * @return the width of the regression insensitive zone
     */
    public double getEps() {
        return this.eps;
    }

    /**
     * Sets the number of epochs used by the batch {@code train} methods.
     *
     * @param epochs a value of at least 1
     * @throws IllegalArgumentException if {@code epochs < 1}
     */
    public void setEpochs(int epochs) {
        if (epochs < 1) {
            throw new IllegalArgumentException("epochs must be a positive value");
        }
        this.epochs = epochs;
    }

    /**
     * @return the number of epochs used by the batch {@code train} methods
     */
    public int getEpochs() {
        return this.epochs;
    }

    /**
     * Returns the live weight vector (not a copy). It is {@code null} before setUp.
     */
    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    /**
     * This model has no bias term, so this always returns 0.
     */
    @Override
    public double getBias() {
        return 0.0;
    }

    /**
     * @param index must be 0, since the model has a single weight vector
     * @return the weight vector
     * @throws IndexOutOfBoundsException if {@code index != 0}
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index == 0) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @param index must be 0, since the model has a single weight vector
     * @return always 0
     * @throws IndexOutOfBoundsException if {@code index != 0}
     */
    @Override
    public double getBias(int index) {
        if (index == 0) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Returns a hard (0/1 probability) prediction: class 1 when the score is strictly
     * positive, class 0 otherwise (including a score of exactly 0).
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        if (this.getScore(data) > 0.0) {
            cr.setProb(1, 1.0);
        } else {
            cr.setProb(0, 1.0);
        }
        return cr;
    }

    /**
     * @return the dot product of the point's numeric features with the weight vector
     */
    @Override
    public double getScore(DataPoint dp) {
        return dp.getNumericalValues().dot(this.w);
    }

    /**
     * Trains sequentially; the {@code parallel} flag is ignored.
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Runs {@link #getEpochs()} epochs of online updates over the data set.
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        BaseUpdateableClassifier.trainEpochs(dataSet, this, this.epochs);
    }

    /**
     * Per-example weights are ignored by the update rules.
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Resets the weight vector to zeros for a binary classification problem.
     *
     * @throws FailedToFitException if the target does not have exactly 2 categories or
     *                              there are no numeric attributes
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("Only supports binary classification problems");
        }
        if (numericAttributes < 1) {
            throw new FailedToFitException("only supports learning from numeric attributes");
        }
        this.w = new DenseVector(numericAttributes);
    }

    /**
     * Resets the weight vector to zeros for a regression problem.
     *
     * @throws FailedToFitException if there are no numeric attributes
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes) {
        if (numericAttributes < 1) {
            throw new FailedToFitException("only supports learning from numeric attributes");
        }
        this.w = new DenseVector(numericAttributes);
    }

    /**
     * Performs one classification update using the hinge loss.
     *
     * @param dataPoint   the example; only its numeric features are used
     * @param weight      ignored
     * @param targetClass 0 or 1, mapped to the label -1 or +1
     */
    @Override
    public void update(DataPoint dataPoint, double weight, int targetClass) {
        int y_t = targetClass * 2 - 1;
        Vec x = dataPoint.getNumericalValues();
        double dot = x.dot(this.w);
        double loss = Math.max(0.0, 1.0 - (double) y_t * dot);
        if (loss == 0.0) {
            return; // passive: margin already satisfied
        }
        double tau = this.getCorrection(loss, x);
        this.w.mutableAdd((double) y_t * tau, x);
    }

    /**
     * Performs one regression update using the epsilon-insensitive loss.
     *
     * @param dataPoint   the example; only its numeric features are used
     * @param weight      ignored
     * @param targetValue the true target value
     */
    @Override
    public void update(DataPoint dataPoint, double weight, double targetValue) {
        Vec x = dataPoint.getNumericalValues();
        double y_t = targetValue;
        double y_p = x.dot(this.w);
        double loss = Math.max(0.0, Math.abs(y_p - y_t) - this.eps);
        if (loss == 0.0) {
            return; // passive: prediction is within eps of the target
        }
        double tau = this.getCorrection(loss, x);
        // Move the prediction toward the target.
        this.w.mutableAdd(Math.signum(y_t - y_p) * tau, x);
    }

    /**
     * Computes the step size {@code tau} for the current mode.
     *
     * @param loss the strictly positive loss of the current example
     * @param x    the example's numeric features
     * @return the step size; 0 when {@code x} has zero norm
     */
    private double getCorrection(double loss, Vec x) {
        double xNorm = x.dot(x); // squared L2 norm
        if (xNorm == 0.0) {
            // An all-zero x cannot change w, whatever tau is. Returning 0 avoids
            // loss/0 = Infinity, which in PA mode becomes Infinity*0 = NaN in a dense w.
            return 0.0;
        }
        if (this.mode == Mode.PA1) {
            return Math.min(this.C, loss / xNorm);
        }
        if (this.mode == Mode.PA2) {
            return loss / (xNorm + 1.0 / (2.0 * this.C));
        }
        return loss / xNorm;
    }

    /**
     * @return the dot product of the weight vector with the point's numeric features
     */
    @Override
    public double regress(DataPoint data) {
        return this.w.dot(data.getNumericalValues());
    }

    /**
     * Trains sequentially; the {@code parallel} flag is ignored.
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Runs {@link #getEpochs()} epochs of online updates over the data set.
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        BaseUpdateableRegressor.trainEpochs(dataSet, this, this.epochs);
    }

    /**
     * Returns an independent copy of this learner, including a deep copy of the weight
     * vector, so later updates to either object do not affect the other.
     */
    @Override
    public PassiveAggressive clone() {
        PassiveAggressive clone = new PassiveAggressive(this.epochs, this.mode);
        clone.eps = this.eps;
        clone.C = this.C;
        if (this.w != null) {
            clone.w = this.w.clone();
        }
        return clone;
    }

    /**
     * Suggests a search distribution for {@code C}: log-uniform over [0.001, 100].
     * The data set is not inspected.
     *
     * @param d the data set (unused)
     * @return the suggested distribution
     */
    public static Distribution guessC(DataSet d) {
        return new LogUniform(0.001, 100.0);
    }

    /** Step-size rule used when the model is updated. */
    public enum Mode {
        /** Unbounded step: {@code loss / ||x||^2}. */
        PA,
        /** Step capped at {@code C}: {@code min(C, loss / ||x||^2)}. */
        PA1,
        /** Softened step: {@code loss / (||x||^2 + 1 / (2C))}. */
        PA2
    }
}

