/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import jsat.DataSet;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.Normal;
import jsat.distributions.Uniform;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.parameters.Parameterized;

/**
 * Soft Confidence-Weighted (SCW) online linear classifier for binary problems.
 * <p>
 * The model keeps a mean weight vector {@link #w} and a covariance over the
 * weights. The covariance is stored either as a full matrix ({@link #sigmaM})
 * or, in {@link #isDiagonalOnly() diagonal-only} mode, as just its diagonal
 * ({@link #sigmaV}). For each training example {@code (x, y)}, with
 * {@code y} in {-1, +1}, {@link #update} computes:
 * <ul>
 * <li>the margin {@code m = y * w.x};</li>
 * <li>the variance {@code v = x^T Sigma x};</li>
 * <li>the loss {@code max(0, phi * sqrt(v) - m)}, where {@code phi} is the
 * standard normal quantile of {@link #getEta() eta}.</li>
 * </ul>
 * If the loss is positive, the mean moves along {@code Sigma x} and the
 * covariance shrinks in the direction of {@code x}. The closed-form step sizes
 * are those of SCW-I / SCW-II in Wang, Zhao and Hoi, "Exact Soft
 * Confidence-Weighted Learning" (ICML 2012).
 * <p>
 * The model has no bias term, and its single weight vector is exposed directly
 * (not copied).
 */
public class SCW extends BaseUpdateableClassifier
        implements BinaryScoreClassifier, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = -6721377074407660742L;

    /** Losses at or below this value are treated as zero (no update). */
    private static final double LOSS_TOLERANCE = 1.0E-15;
    /** Step sizes below this value are treated as "no update". */
    private static final double MIN_ALPHA = 1.0E-7;

    // NOTE: instance field names are kept as-is because they form the
    // serialized representation of this class.

    /** Aggressiveness: caps the SCW-I step, regularizes the SCW-II step. */
    private double C = 1.0;
    /** Confidence parameter, in [0.5, 1). */
    private double eta;
    /** phi = inverse standard normal CDF of eta. */
    private double phi;
    /** phi^2, cached. */
    private double phiSqrd;
    /** zeta = 1 + phi^2, cached. */
    private double zeta;
    /** psi = 1 + phi^2 / 2, cached. */
    private double psi;
    /** Which step-size rule {@link #update} uses. */
    private Mode mode;
    /** Mean weight vector; {@code null} until {@link #setUp} is called. */
    private Vec w;
    /** Full covariance matrix (only allocated when not diagonal-only). */
    private Matrix sigmaM;
    /** Diagonal of the covariance (only allocated when diagonal-only). */
    private Vec sigmaV;
    /** Scratch buffer holding {@code Sigma * x} in full-matrix mode. */
    private Vec Sigma_xt;
    /** Whether to keep only the diagonal of the covariance. */
    private boolean diagonalOnly = false;

    /**
     * Creates an SCW-I classifier with {@code eta = 0.5}, {@code C = 1} and a
     * diagonal-only covariance.
     */
    public SCW() {
        this(0.5, Mode.SCWI, true);
    }

    /**
     * Creates a new SCW classifier with {@code C = 1}.
     *
     * @param eta the confidence parameter, in {@code [0.5, 1)}
     * @param mode the step-size rule to use
     * @param diagonalOnly {@code true} to keep only the diagonal of the
     *        covariance, {@code false} to keep the full matrix
     * @throws IllegalArgumentException if {@code eta} is out of range
     * @throws NullPointerException if {@code mode} is {@code null}
     */
    public SCW(double eta, Mode mode, boolean diagonalOnly) {
        this.setEta(eta);
        this.setMode(mode);
        this.setDiagonalOnly(diagonalOnly);
    }

    /**
     * Copy constructor. Hyper-parameters are copied. The weight vector, the
     * covariance storage and the scratch buffer are cloned when present.
     *
     * @param other the classifier to copy
     */
    protected SCW(SCW other) {
        this.C = other.C;
        this.diagonalOnly = other.diagonalOnly;
        this.mode = other.mode;
        // Copy the derived constants directly instead of calling the
        // overridable, validating setEta from a constructor.
        this.eta = other.eta;
        this.phi = other.phi;
        this.phiSqrd = other.phiSqrd;
        this.zeta = other.zeta;
        this.psi = other.psi;
        if (other.w != null) {
            this.w = other.w.clone();
        }
        if (other.sigmaM != null) {
            this.sigmaM = other.sigmaM.clone();
        }
        if (other.sigmaV != null) {
            this.sigmaV = other.sigmaV.clone();
        }
        if (other.Sigma_xt != null) {
            this.Sigma_xt = other.Sigma_xt.clone();
        }
    }

    /**
     * Sets the confidence parameter {@code eta}.
     * <p>
     * The loss asks for {@code m >= phi * sqrt(v)} with {@code phi = Phi^-1(eta)}.
     * In other words, the margin should be non-negative with probability at
     * least {@code eta} under the model's variance. The derived constants
     * {@code phi}, {@code phi^2}, {@code zeta = 1 + phi^2} and
     * {@code psi = 1 + phi^2 / 2} are recomputed here.
     *
     * @param eta the confidence, in {@code [0.5, 1)}
     * @throws IllegalArgumentException if {@code eta} is NaN or outside
     *         {@code [0.5, 1)}
     */
    public void setEta(double eta) {
        // eta == 1 is rejected: Phi^-1(1) is infinite, and an infinite phi
        // turns every step size into NaN, which then corrupts the weights.
        if (!(eta >= 0.5 && eta < 1.0)) {
            throw new IllegalArgumentException("eta must be in [0.5, 1) not " + eta);
        }
        this.eta = eta;
        this.phi = Normal.invcdf(eta, 0.0, 1.0);
        this.phiSqrd = this.phi * this.phi;
        this.zeta = 1.0 + this.phiSqrd;
        this.psi = 1.0 + this.phiSqrd / 2.0;
    }

    /**
     * Returns the confidence parameter.
     *
     * @return the confidence parameter {@code eta}
     */
    public double getEta() {
        return this.eta;
    }

    /**
     * Sets the aggressiveness parameter {@code C}. It is used as follows:
     * <ul>
     * <li>{@link Mode#SCWI}: {@code C} caps the step size;</li>
     * <li>{@link Mode#SCWII}: {@code C} enters through {@code n = v + 1/(2C)};</li>
     * <li>{@link Mode#CW}: {@code C} is not used.</li>
     * </ul>
     *
     * @param C the aggressiveness; must be positive
     * @throws IllegalArgumentException if {@code C} is NaN or not positive
     */
    public void setC(double C) {
        // C <= 0 disables every SCW-I update (the step is capped at C < MIN_ALPHA),
        // and C == 0 makes the SCW-II step NaN (1 / (2C) is infinite).
        if (!(C > 0.0)) {
            throw new IllegalArgumentException("C must be a positive constant, not " + C);
        }
        this.C = C;
    }

    /**
     * Returns the aggressiveness parameter.
     *
     * @return the aggressiveness parameter {@code C}
     */
    public double getC() {
        return this.C;
    }

    /**
     * Sets which step-size rule {@link #update} uses.
     *
     * @param mode the update mode
     * @throws NullPointerException if {@code mode} is {@code null}
     */
    public void setMode(Mode mode) {
        this.mode = java.util.Objects.requireNonNull(mode, "mode must not be null");
    }

    /**
     * Returns the current update mode.
     *
     * @return the current update mode
     */
    public Mode getMode() {
        return this.mode;
    }

    /**
     * Chooses between a diagonal-only and a full covariance matrix.
     * <p>
     * {@link #setUp} allocates the covariance storage according to this flag,
     * and {@link #update} branches on it. Change it before training (or call
     * {@code setUp} again afterwards).
     *
     * @param diagonalOnly {@code true} to keep only the diagonal
     */
    public void setDiagonalOnly(boolean diagonalOnly) {
        this.diagonalOnly = diagonalOnly;
    }

    /**
     * Returns whether only the covariance diagonal is kept.
     *
     * @return {@code true} if only the diagonal of the covariance is kept
     */
    public boolean isDiagonalOnly() {
        return this.diagonalOnly;
    }

    /**
     * Returns the mean weight vector. This is the live internal vector, not a
     * copy.
     *
     * @return the mean weight vector, or {@code null} before {@link #setUp}
     */
    public Vec getWeightVec() {
        return this.w;
    }

    /**
     * Returns the mean weight vector. This is the live internal vector, not a
     * copy.
     *
     * @return the mean weight vector, or {@code null} before {@link #setUp}
     */
    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    /**
     * Returns the bias term. This model has no bias, so the result is always
     * {@code 0}.
     *
     * @return {@code 0.0}
     */
    @Override
    public double getBias() {
        return 0.0;
    }

    /**
     * Returns the weight vector with the given index.
     *
     * @param index must be {@code 0}
     * @return the mean weight vector
     * @throws IndexOutOfBoundsException if {@code index != 0}
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index == 0) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the bias with the given index.
     *
     * @param index must be {@code 0}
     * @return {@code 0.0}
     * @throws IndexOutOfBoundsException if {@code index != 0}
     */
    @Override
    public double getBias(int index) {
        if (index == 0) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the number of weight vectors.
     *
     * @return {@code 1}
     */
    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Returns a copy of this classifier.
     *
     * @return a copy made via the copy constructor
     */
    @Override
    public SCW clone() {
        return new SCW(this);
    }

    /**
     * Allocates a fresh, untrained model: zero mean and identity covariance.
     * Depending on {@link #isDiagonalOnly()}, the covariance is either a vector
     * of ones or an identity matrix.
     *
     * @param categoricalAttributes ignored; only numeric features are used
     * @param numericAttributes number of numeric features; must be positive
     * @param predicting the target variable; must have exactly 2 categories
     * @throws FailedToFitException if there are no numeric attributes or the
     *         target is not binary
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (numericAttributes <= 0) {
            throw new FailedToFitException("SCW requires numeric attributes to perform classification");
        }
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("SCW is a binary classifier");
        }
        this.w = new DenseVector(numericAttributes);
        this.Sigma_xt = new DenseVector(numericAttributes);
        if (this.diagonalOnly) {
            this.sigmaV = new DenseVector(numericAttributes);
            this.sigmaV.mutableAdd(1.0);
        } else {
            this.sigmaM = Matrix.eye(numericAttributes);
        }
    }

    /**
     * Performs one SCW step on a single labelled example.
     * <p>
     * Nothing changes when the example already satisfies the confidence margin
     * {@code y * w.x >= phi * sqrt(x^T Sigma x)} (for example, an all-zero
     * feature vector), or when the computed step size is negligible.
     *
     * @param dataPoint the example; only its numeric features are used
     * @param weight ignored; this classifier does not support weighted data
     * @param targetClass the label, 0 or 1 (mapped to -1 / +1)
     * @throws FailedToFitException if an update is required but the variance
     *         {@code x^T Sigma x} is not positive (or is NaN)
     */
    @Override
    public void update(DataPoint dataPoint, double weight, int targetClass) {
        Vec x_t = dataPoint.getNumericalValues();
        double y_t = targetClass * 2 - 1;
        double m_t = y_t * x_t.dot(this.w);
        double v_t = this.computeVariance(x_t);

        double loss = Math.max(0.0, this.phi * Math.sqrt(v_t) - m_t);
        // The loss is tested before the variance check so that a zero feature
        // vector (v_t == 0, m_t == 0, loss == 0) is simply skipped. A negative
        // or NaN v_t makes sqrt, and therefore the loss, NaN. NaN fails this
        // test and is rejected just below.
        if (loss <= LOSS_TOLERANCE) {
            return;
        }
        if (!(v_t > 0.0)) {
            throw new FailedToFitException("Numerical issues occurred");
        }

        double alpha_t = this.computeAlpha(m_t, v_t);
        if (alpha_t < MIN_ALPHA) {
            return;
        }

        // u_t = (1/4) * (-alpha v phi + sqrt(alpha^2 v^2 phi^2 + 4 v))^2
        double s = -alpha_t * v_t * this.phi
                + Math.sqrt(alpha_t * alpha_t * v_t * v_t * this.phiSqrd + 4.0 * v_t);
        double u_t = s * s / 4.0;

        if (this.diagonalOnly) {
            // Each coordinate is independent in diagonal mode. The mean step
            // uses the old variance S, then the inverse variance grows by
            // coef * x_i^2.
            double coef = alpha_t * this.phi / Math.sqrt(u_t);
            for (IndexValue iv : x_t) {
                int idx = iv.getIndex();
                double x_t_i = iv.getValue();
                double S_rr = this.sigmaV.get(idx);
                this.w.increment(idx, alpha_t * y_t * x_t_i * S_rr);
                this.sigmaV.set(idx, 1.0 / (1.0 / S_rr + coef * x_t_i * x_t_i));
            }
        } else {
            // w += alpha * y * Sigma x; Sigma -= beta * (Sigma x)(Sigma x)^T
            double beta_t = alpha_t * this.phi / (Math.sqrt(u_t) + v_t * alpha_t * this.phi);
            this.w.mutableAdd(alpha_t * y_t, this.Sigma_xt);
            Matrix.OuterProductUpdate(this.sigmaM, this.Sigma_xt, this.Sigma_xt, -beta_t);
        }
    }

    /**
     * Computes the variance {@code v = x^T Sigma x} of the margin for
     * {@code x_t}. In full-matrix mode this also leaves {@code Sigma x} in
     * {@link #Sigma_xt} for the subsequent weight and covariance update.
     */
    private double computeVariance(Vec x_t) {
        if (this.diagonalOnly) {
            double v_t = 0.0;
            for (IndexValue iv : x_t) {
                double x_t_i = iv.getValue();
                v_t += x_t_i * x_t_i * this.sigmaV.get(iv.getIndex());
            }
            return v_t;
        }
        // Clear the scratch buffer right before use. It is then clean even if a
        // previous update returned early or threw. NOTE(review): multiply(x, 1,
        // out) presumably accumulates into `out`; clearing is harmless either way.
        this.Sigma_xt.zeroOut();
        this.sigmaM.multiply(x_t, 1.0, this.Sigma_xt);
        return x_t.dot(this.Sigma_xt);
    }

    /**
     * Computes the closed-form step size alpha for the current {@link #mode},
     * given the margin {@code m_t} and the (positive) variance {@code v_t}.
     */
    private double computeAlpha(double m_t, double v_t) {
        if (this.mode == Mode.SCWI || this.mode == Mode.CW) {
            // CW and SCW-I share this expression; SCW-I additionally caps it at C.
            double tmp = Math.max(0.0,
                    (-m_t * this.psi
                            + Math.sqrt(m_t * m_t * this.phiSqrd * this.phiSqrd / 4.0
                                    + v_t * this.phiSqrd * this.zeta))
                            / (v_t * this.zeta));
            return this.mode == Mode.SCWI ? Math.min(this.C, tmp) : tmp;
        }
        // SCW-II
        double n_t = v_t + 1.0 / (2.0 * this.C);
        double gamma = this.phi * Math.sqrt(this.phiSqrd * v_t * v_t * m_t * m_t
                + 4.0 * n_t * v_t * (n_t + v_t * this.phiSqrd));
        return Math.max(0.0,
                (-(2.0 * m_t * n_t + this.phiSqrd * m_t * v_t) + gamma)
                        / (2.0 * (n_t * n_t + n_t * v_t * this.phiSqrd)));
    }

    /**
     * Classifies a data point with a hard decision: all probability goes to
     * class 0 if the score is negative, and to class 1 otherwise (including a
     * score of exactly 0).
     *
     * @param data the point to classify
     * @return a two-class result with probability 1 on the predicted class
     * @throws UntrainedModelException if the model has not been set up
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        double score = this.getScore(data);
        CategoricalResults cr = new CategoricalResults(2);
        cr.setProb(score < 0.0 ? 0 : 1, 1.0);
        return cr;
    }

    /**
     * Returns the raw linear score {@code w.x}. Negative values indicate class
     * 0 and non-negative values indicate class 1.
     *
     * @param dp the point to score; only numeric features are used
     * @return the dot product of the mean weight vector and the features
     * @throws UntrainedModelException if the model has not been set up
     */
    @Override
    public double getScore(DataPoint dp) {
        if (this.w == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        return this.w.dot(dp.getNumericalValues());
    }

    /**
     * Reports whether example weights are used. The {@code weight} argument of
     * {@link #update} is ignored.
     *
     * @return {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Suggests a search distribution for {@code C} for hyper-parameter tuning:
     * log-uniform over {@code [2^-4, 2^4]}.
     *
     * @param d the data set (not used)
     * @return the suggested distribution
     */
    public static Distribution guessC(DataSet d) {
        return new LogUniform(Math.pow(2.0, -4.0), Math.pow(2.0, 4.0));
    }

    /**
     * Suggests a search distribution for {@code eta} for hyper-parameter
     * tuning: uniform over {@code [0.5, 0.95]}.
     *
     * @param d the data set (not used)
     * @return the suggested distribution
     */
    public static Distribution guessEta(DataSet d) {
        return new Uniform(0.5, 0.95);
    }

    /** Step-size rule used by {@link SCW#update}. */
    public static enum Mode {
        /** Confidence-weighted step without the {@code C} cap. */
        CW,
        /** SCW-I: same step as {@link #CW}, capped at {@code C}. */
        SCWI,
        /** SCW-II: step regularized through {@code n = v + 1/(2C)}. */
        SCWII;

    }
}

