/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear.kernelized;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.linear.Vec;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.random.RandomUtil;

/**
 * Online kernel logistic regression for binary classification that stochastically decides whether
 * each training example is stored as a new support vector, and rescales all coefficients whenever
 * the tracked model norm exceeds {@link #getR() R}.
 *
 * <p>Class indices are mapped to labels {@code y = 2 * targetClass - 1}: class 0 becomes -1 and
 * class 1 becomes +1. Whether an example is kept, and the size of its coefficient, are controlled by
 * the configured {@link UpdateMode}. NOTE(review): the name presumably stands for "Conservative
 * Stochastic Kernel Logistic Regression" — confirm against the reference paper.
 */
public class CSKLR
extends BaseUpdateableClassifier
implements Parameterized {
    private static final long serialVersionUID = 2325605193408720811L;
    /** Learning rate that multiplies every new coefficient. */
    private double eta;
    /** Kernel expansion coefficients; index i pairs with {@code vecs.get(i)}. */
    private DoubleList alpha;
    /** Support vectors kept so far; index i pairs with {@code alpha.get(i)}. */
    private List<Vec> vecs;
    /** Running value of sum_i |alpha_i| * k(x_i, x_i); compared against {@link #R}. */
    private double curNorm;
    /** Kernel used for scoring and for the norm bookkeeping. */
    private KernelTrick k;
    /** Maximum allowed value of {@link #curNorm} before coefficients are rescaled. */
    private double R;
    /** Source of randomness for the stochastic keep/skip decision. */
    private Random rand;
    /** Rule that determines the keep probability and the gradient of each update. */
    private UpdateMode mode;
    /** Extra parameter consumed by the AUXILIARY_* update modes. */
    private double gamma = 2.0;
    /** Kernel acceleration cache, or {@code null} when the kernel does not support acceleration. */
    private List<Double> accelCache;

    /**
     * Creates a new classifier.
     *
     * @param eta  the learning rate, must be non-negative and finite
     * @param k    the kernel to use
     * @param R    the maximum tracked model norm, must be non-negative and finite
     * @param mode the update rule
     * @throws IllegalArgumentException if {@code eta} or {@code R} is negative, NaN or infinite
     */
    public CSKLR(double eta, KernelTrick k, double R, UpdateMode mode) {
        this.setEta(eta);
        this.setKernel(k);
        this.setR(R);
        this.setMode(mode);
    }

    /**
     * Returns a distribution of candidate values for {@code R} for hyper-parameter search. The
     * data set is currently ignored and a fixed log-uniform range over [1, 1e5] is returned.
     *
     * @param d the data set the search will be performed on (unused)
     * @return the distribution to sample {@code R} from
     */
    public static Distribution guessR(DataSet d) {
        return new LogUniform(1.0, 100000.0);
    }

    /**
     * Copy constructor. The coefficient, support-vector and cache lists are copied, so they are
     * independent of the source; the stored {@link Vec} objects themselves are shared. The kernel
     * is cloned and a fresh random source is obtained.
     *
     * @param toClone the classifier to copy
     */
    protected CSKLR(CSKLR toClone) {
        if (toClone.alpha != null) {
            this.alpha = new DoubleList(toClone.alpha);
        }
        if (toClone.vecs != null) {
            this.vecs = new ArrayList<Vec>(toClone.vecs);
        }
        this.curNorm = toClone.curNorm;
        this.mode = toClone.mode;
        this.R = toClone.R;
        this.eta = toClone.eta;
        this.setKernel(toClone.k.clone());
        if (toClone.accelCache != null) {
            this.accelCache = new DoubleList(toClone.accelCache);
        }
        this.gamma = toClone.gamma;
        this.rand = RandomUtil.getRandom();
        this.setEpochs(toClone.getEpochs());
    }

    /**
     * Sets the learning rate.
     *
     * @param eta the learning rate
     * @throws IllegalArgumentException if {@code eta} is negative, NaN or infinite
     */
    public void setEta(double eta) {
        if (eta < 0.0 || Double.isNaN(eta) || Double.isInfinite(eta)) {
            throw new IllegalArgumentException("The learning rate should be in (0, Inf), not " + eta);
        }
        this.eta = eta;
    }

    /**
     * @return the learning rate
     */
    public double getEta() {
        return this.eta;
    }

    /**
     * Sets the maximum tracked model norm; exceeding it triggers a rescaling of all coefficients.
     *
     * @param R the maximum norm
     * @throws IllegalArgumentException if {@code R} is negative, NaN or infinite
     */
    public void setR(double R) {
        if (R < 0.0 || Double.isNaN(R) || Double.isInfinite(R)) {
            throw new IllegalArgumentException("The max norm should be in (0, Inf), not " + R);
        }
        this.R = R;
    }

    /**
     * @return the maximum tracked model norm
     */
    public double getR() {
        return this.R;
    }

    /**
     * @param mode the update rule to use for subsequent updates
     */
    public void setMode(UpdateMode mode) {
        this.mode = mode;
    }

    /**
     * @return the update rule in use
     */
    public UpdateMode getMode() {
        return this.mode;
    }

    /**
     * Sets the parameter used by the AUXILIARY_* update modes.
     *
     * @param gamma the new value
     * @throws IllegalArgumentException if {@code gamma} is negative, NaN or infinite
     */
    public void setGamma(double gamma) {
        if (gamma < 0.0 || Double.isNaN(gamma) || Double.isInfinite(gamma)) {
            throw new IllegalArgumentException("Gamma must be in (0, Infinity), not " + gamma);
        }
        this.gamma = gamma;
    }

    /**
     * @return the parameter used by the AUXILIARY_* update modes
     */
    public double getGamma() {
        return this.gamma;
    }

    /**
     * @param k the kernel to use
     */
    public void setKernel(KernelTrick k) {
        this.k = k;
    }

    /**
     * @return the kernel in use
     */
    public KernelTrick getKernel() {
        return this.k;
    }

    /**
     * Computes the raw decision value sum_i alpha_i * k(x_i, x) over all stored support vectors.
     */
    private double getPreScore(Vec x) {
        return this.k.evalSum(this.vecs, this.accelCache, this.alpha.getBackingArray(), x, 0, this.alpha.size());
    }

    /**
     * Logistic function of the signed margin.
     *
     * @param y   the label, -1 or +1
     * @param pre the raw decision value
     * @return {@code 1 / (1 + exp(-y * pre))}
     */
    protected static double getScore(double y, double pre) {
        return 1.0 / (1.0 + Math.exp(-y * pre));
    }

    @Override
    public CSKLR clone() {
        return new CSKLR(this);
    }

    /**
     * Resets the model to an empty expansion.
     *
     * @throws FailedToFitException if the target does not have exactly two categories
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("CSKLR supports only binary classification");
        }
        this.alpha = new DoubleList();
        this.vecs = new ArrayList<Vec>();
        this.curNorm = 0.0;
        this.rand = RandomUtil.getRandom();
        // Always reset the cache so a cache built for a previously configured kernel is not reused.
        this.accelCache = this.k.supportsAcceleration() ? new DoubleList() : null;
    }

    /**
     * Processes one training example. Unless the mode is {@link UpdateMode#NC}, the example is
     * kept only with probability {@code mode.pt(...)}. A kept example is appended as a support
     * vector with coefficient {@code -eta * y * mode.grad(...) * weight}. If the tracked norm then
     * exceeds {@code R}, all coefficients are scaled down so the tracked norm equals {@code R}.
     *
     * @param dataPoint   the example; only its numerical values are used
     * @param weight      the example weight, multiplied into the new coefficient
     * @param targetClass the class index, 0 or 1
     */
    @Override
    public void update(DataPoint dataPoint, double weight, int targetClass) {
        // Map class {0, 1} to label {-1, +1}.
        final double y = targetClass * 2 - 1;
        final Vec x = dataPoint.getNumericalValues();
        final double pre = this.getPreScore(x);
        final double score = CSKLR.getScore(y, pre);

        // NC always keeps the example, so it skips the random draw entirely.
        if (this.mode != UpdateMode.NC) {
            final double pt = this.mode.pt(y, score, pre, this.eta, this.gamma);
            // Keep with probability pt; a NaN pt compares false here, so the example is kept.
            if (this.rand.nextDouble() > pt) {
                return;
            }
        }

        final double alphaNew = -this.eta * y * this.mode.grad(y, score, pre, this.gamma) * weight;
        this.alpha.add(alphaNew);
        this.vecs.add(x);
        this.k.addToCache(x, this.accelCache);

        // The vector just added lives at the last index; vecs.size() itself is out of range.
        final int last = this.vecs.size() - 1;
        this.curNorm += Math.abs(alphaNew) * this.k.eval(last, last, this.vecs, this.accelCache);

        // Projection: scaling every alpha by coef scales sum_i |alpha_i| * k(x_i, x_i) by coef too,
        // so the tracked norm must be scaled as well (it becomes exactly R).
        if (this.curNorm > this.R) {
            final double coef = this.R / this.curNorm;
            for (int i = 0; i < this.alpha.size(); ++i) {
                this.alpha.set(i, this.alpha.get(i) * coef);
            }
            this.curNorm *= coef;
        }
    }

    /**
     * Returns class probabilities from the logistic model: index 0 receives P(y = -1) and index 1
     * receives P(y = +1).
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        double p_0 = CSKLR.getScore(-1.0, this.getPreScore(data.getNumericalValues()));
        cr.setProb(0, p_0);
        cr.setProb(1, 1.0 - p_0);
        return cr;
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Update rules. Each rule supplies the probability {@code pt} of keeping an example and the
     * gradient factor {@code grad} used to size its coefficient. Below, {@code z = y * preScore}
     * and {@code score = 1 / (1 + exp(-z))}.
     */
    public enum UpdateMode {
        /** Keeps every example ({@code pt = 1}); {@code grad = score - 1}. */
        NC{

            @Override
            protected double pt(double y, double score, double preScore, double eta, double gamma) {
                return 1.0;
            }

            @Override
            protected double grad(double y, double score, double preScore, double gamma) {
                return score - 1.0;
            }
        }
        ,
        /** {@code pt = (2 - eta) / (2 - eta + eta * score)}; {@code grad = score - 1}. */
        MARGIN{

            @Override
            protected double pt(double y, double score, double preScore, double eta, double gamma) {
                return (2.0 - eta) / (2.0 - eta + eta * score);
            }

            @Override
            protected double grad(double y, double score, double preScore, double gamma) {
                return score - 1.0;
            }
        }
        ,
        /**
         * {@code pt = ln(1 + e^-z) / ln(gamma + e^-z)}; {@code grad = -1 / (1 + gamma * e^z)}.
         */
        AUXILIARY_1{

            @Override
            protected double pt(double y, double score, double preScore, double eta, double gamma) {
                double z = y * preScore;
                return Math.log(1.0 + Math.exp(-z)) / Math.log(gamma + Math.exp(-z));
            }

            @Override
            protected double grad(double y, double score, double preScore, double gamma) {
                double z = y * preScore;
                return -1.0 / (1.0 + gamma * Math.exp(z));
            }
        }
        ,
        /**
         * {@code pt = ln(1 + e^-z) / ln(1 + gamma * e^-z)}; {@code grad = -gamma / (gamma + e^z)}.
         */
        AUXILIARY_2{

            @Override
            protected double pt(double y, double score, double preScore, double eta, double gamma) {
                double z = y * preScore;
                return Math.log(1.0 + Math.exp(-z)) / Math.log(1.0 + gamma * Math.exp(-z));
            }

            @Override
            protected double grad(double y, double score, double preScore, double gamma) {
                double z = y * preScore;
                return -gamma / (gamma + Math.exp(z));
            }
        }
        ,
        /** {@code pt = ln(1 + e^-z) / ln(1 + e^-gamma)}; {@code grad = score - 1}. */
        AUXILIARY_3{

            @Override
            protected double pt(double y, double score, double preScore, double eta, double gamma) {
                double z = y * preScore;
                return Math.log(1.0 + Math.exp(-z)) / Math.log(1.0 + Math.exp(-gamma));
            }

            @Override
            protected double grad(double y, double score, double preScore, double gamma) {
                return score - 1.0;
            }
        };


        /**
         * Probability of keeping the current example as a support vector.
         *
         * @param y        the label, -1 or +1
         * @param score    logistic score of the signed margin
         * @param preScore raw decision value
         * @param eta      learning rate
         * @param gamma    auxiliary-mode parameter
         * @return the keep probability
         */
        protected abstract double pt(double y, double score, double preScore, double eta, double gamma);

        /**
         * Gradient factor that, multiplied by {@code -eta * y * weight}, gives the new coefficient.
         *
         * @param y        the label, -1 or +1
         * @param score    logistic score of the signed margin
         * @param preScore raw decision value
         * @param gamma    auxiliary-mode parameter
         * @return the gradient factor
         */
        protected abstract double grad(double y, double score, double preScore, double gamma);
    }
}

