/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.Arrays;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.parameters.Parameterized;
import jsat.regression.BaseUpdateableRegressor;
import jsat.regression.RegressionDataSet;
import jsat.regression.UpdateableRegressor;

/**
 * Sparse Truncated Gradient Descent (STGD): an online linear model that
 * encourages sparse weights by periodically shrinking ("truncating") small
 * coefficients toward zero.
 * <p>
 * The model is a single weight vector with no bias term. It can be used as a
 * binary classifier (target classes 0/1 are mapped to labels -1/+1) or as a
 * regressor. Truncation is applied lazily: coefficient {@code j} is only
 * shrunk when index {@code j} is visited while iterating an input vector. The
 * amount of shrinkage is proportional to the number of whole {@code K}-update
 * periods that have elapsed since that coefficient was last truncated.
 * <p>
 * Weighted data is not supported; the per-example weight passed to the update
 * methods is ignored.
 */
public class STGD
extends BaseUpdateableClassifier
implements UpdateableRegressor,
BinaryScoreClassifier,
Parameterized,
SingleWeightVectorModel {
    private static final long serialVersionUID = 5753298014967370769L;
    /** The weight vector; allocated by {@code setUp}, {@code null} before that. */
    private Vec w;
    /** Number of updates between truncation periods. */
    private int K;
    private double learningRate;
    /** Coefficients whose magnitude exceeds this value are never truncated. */
    private double threshold;
    /** Per-period shrinkage strength (scaled by the learning rate). */
    private double gravity;
    /** Number of update calls performed since the last {@code setUp}. */
    private int time;
    /**
     * For each coefficient, the time up to which its truncation has already
     * been applied (always advanced in multiples of {@code K}).
     */
    private int[] t;

    /**
     * Creates a new STGD learner.
     *
     * @param K the number of updates per truncation period; must be at least 1
     * @param learningRate the gradient step size; must be positive and finite
     * @param threshold coefficients with magnitude above this value are not
     *        truncated; must be positive (positive infinity is accepted)
     * @param gravity the shrinkage applied per period; must be positive and finite
     * @throws IllegalArgumentException if any parameter is out of range
     */
    public STGD(int K, double learningRate, double threshold, double gravity) {
        this.setK(K);
        this.setLearningRate(learningRate);
        this.setThreshold(threshold);
        this.setGravity(gravity);
    }

    /**
     * Copy constructor producing an independent deep copy of {@code toCopy},
     * including its learned weights and training progress.
     *
     * @param toCopy the learner to copy
     */
    protected STGD(STGD toCopy) {
        // The epoch count lives in the base class; copy it so that clones
        // train for the same number of passes as the original.
        this.setEpochs(toCopy.getEpochs());
        if (toCopy.w != null) {
            this.w = toCopy.w.clone();
        }
        this.K = toCopy.K;
        this.learningRate = toCopy.learningRate;
        this.threshold = toCopy.threshold;
        this.gravity = toCopy.gravity;
        this.time = toCopy.time;
        if (toCopy.t != null) {
            this.t = Arrays.copyOf(toCopy.t, toCopy.t.length);
        }
    }

    /**
     * Sets the number of updates per truncation period.
     *
     * @param K the period length; must be at least 1
     * @throws IllegalArgumentException if {@code K < 1}
     */
    public void setK(int K) {
        if (K < 1) {
            throw new IllegalArgumentException("K must be positive, not " + K);
        }
        this.K = K;
    }

    /**
     * @return the number of updates per truncation period
     */
    public int getK() {
        return this.K;
    }

    /**
     * Sets the gradient step size.
     *
     * @param learningRate the step size; must be positive and finite
     * @throws IllegalArgumentException if the value is not positive and finite
     */
    public void setLearningRate(double learningRate) {
        if (Double.isInfinite(learningRate) || Double.isNaN(learningRate) || learningRate <= 0.0) {
            throw new IllegalArgumentException("Learning rate must be positive, not " + learningRate);
        }
        this.learningRate = learningRate;
    }

    /**
     * @return the gradient step size
     */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the magnitude above which coefficients are exempt from truncation.
     *
     * @param threshold the threshold; must be positive (positive infinity is
     *        allowed, meaning every coefficient may be truncated)
     * @throws IllegalArgumentException if the value is NaN or not positive
     */
    public void setThreshold(double threshold) {
        if (Double.isNaN(threshold) || threshold <= 0.0) {
            throw new IllegalArgumentException("Threshold must be positive, not " + threshold);
        }
        this.threshold = threshold;
    }

    /**
     * @return the magnitude above which coefficients are not truncated
     */
    public double getThreshold() {
        return this.threshold;
    }

    /**
     * Sets the per-period shrinkage strength.
     *
     * @param gravity the gravity; must be positive and finite
     * @throws IllegalArgumentException if the value is not positive and finite
     */
    public void setGravity(double gravity) {
        if (Double.isInfinite(gravity) || Double.isNaN(gravity) || gravity <= 0.0) {
            throw new IllegalArgumentException("Gravity must be positive, not " + gravity);
        }
        this.gravity = gravity;
    }

    /**
     * @return the per-period shrinkage strength
     */
    public double getGravity() {
        return this.gravity;
    }

    /**
     * Returns the live (not copied) weight vector, or {@code null} if the model
     * has not been set up yet.
     */
    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    /** This model has no bias term, so the bias is always 0. */
    @Override
    public double getBias() {
        return 0.0;
    }

    /**
     * Returns the weight vector for {@code index}; only index 0 exists.
     *
     * @throws IndexOutOfBoundsException if {@code index != 0}
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index == 0) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the bias for {@code index}; only index 0 exists.
     *
     * @throws IndexOutOfBoundsException if {@code index != 0}
     */
    @Override
    public double getBias(int index) {
        if (index == 0) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public int numWeightsVecs() {
        return 1;
    }

    @Override
    public STGD clone() {
        return new STGD(this);
    }

    /**
     * Prepares the model for classification training.
     *
     * @throws FailedToFitException if the target does not have exactly two
     *         categories, or if there are no numeric features
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("STGD supports only binary classification");
        }
        this.setUp(categoricalAttributes, numericAttributes);
    }

    /**
     * Allocates fresh (zero) weights and resets all training progress. Only
     * numeric features are used; categorical attributes are ignored.
     *
     * @throws FailedToFitException if there are no numeric features
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes) {
        if (numericAttributes < 1) {
            throw new FailedToFitException("STGD requires numeric features");
        }
        this.w = new DenseVector(numericAttributes);
        this.t = new int[numericAttributes];
        // Reset the update clock together with t[]. A stale value would make
        // the first updates of a retrained model see a huge elapsed time and
        // over-truncate every coefficient.
        this.time = 0;
    }

    /** Trains as a regressor; parallel training is not supported and the flag is ignored. */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /** Trains as a regressor for {@link #getEpochs()} passes over the data. */
    @Override
    public void train(RegressionDataSet dataSet) {
        BaseUpdateableRegressor.trainEpochs(dataSet, this, this.getEpochs());
    }

    /**
     * Truncation operator. Values in {@code [0, theta]} are shrunk toward zero
     * by {@code a}, without crossing zero. Values in {@code [-theta, 0]} are
     * treated symmetrically, and values beyond {@code theta} in magnitude are
     * returned unchanged.
     */
    private static double T(double v_j, double a, double theta) {
        if (v_j >= 0.0 && v_j <= theta) {
            return Math.max(0.0, v_j - a);
        }
        if (v_j <= 0.0 && v_j >= -theta) {
            return Math.min(0.0, v_j + a);
        }
        return v_j;
    }

    /**
     * Performs one classification update. {@code targetClass} (0 or 1) is
     * mapped to a label in {-1, +1}. Weights are only changed when the sign of
     * the current score disagrees with the label. A score of exactly 0 counts
     * as a mistake, because {@code signum} then yields 0. The update clock
     * still advances on correct predictions. {@code weight} is ignored.
     */
    @Override
    public void update(DataPoint dataPoint, double weight, int targetClass) {
        ++this.time;
        Vec x = dataPoint.getNumericalValues();
        int y = targetClass * 2 - 1;
        int yHat = (int) Math.signum(this.w.dot(x));
        if (yHat == y) {
            return;
        }
        this.performUpdate(x, y, yHat);
    }

    /**
     * Performs one regression update using the raw score as the prediction.
     * {@code weight} is ignored.
     */
    @Override
    public void update(DataPoint dataPoint, double weight, double y) {
        ++this.time;
        Vec x = dataPoint.getNumericalValues();
        double yHat = this.w.dot(x);
        this.performUpdate(x, y, yHat);
    }

    /**
     * Applies the gradient step followed by the lazy truncation. Only the
     * coefficients whose indices are visited while iterating {@code x} are
     * touched (presumably the non-zero entries of {@code x}; confirm against
     * {@code Vec}'s iterator).
     */
    private void performUpdate(Vec x, double y, double yHat) {
        // Step of the form 2 * eta * (y - yHat) * x_j, i.e. a squared-error gradient step.
        double step = 2.0 * this.learningRate * (y - yHat);
        for (IndexValue iv : x) {
            int j = iv.getIndex();
            // Whole K-update periods since coefficient j was last truncated;
            // the integer division is intentional.
            int periods = (this.time - this.t[j]) / this.K;
            double shrink = periods * this.gravity * this.learningRate;
            this.w.set(j, STGD.T(this.w.get(j) + step * iv.getValue(), shrink, this.threshold));
            // Consume only the whole periods; the remainder carries over.
            this.t[j] += periods * this.K;
        }
    }

    /**
     * Returns a hard prediction: class 1 with probability 1 if the score is
     * strictly positive, otherwise class 0 with probability 1.
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        if (this.getScore(data) > 0.0) {
            cr.setProb(1, 1.0);
        } else {
            cr.setProb(0, 1.0);
        }
        return cr;
    }

    /** Returns the raw linear score as the regression prediction. */
    @Override
    public double regress(DataPoint data) {
        return this.getScore(data);
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /** Returns the dot product of the weight vector with the numeric features. */
    @Override
    public double getScore(DataPoint dp) {
        return this.w.dot(dp.getNumericalValues());
    }
}

