/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.Random;
import jsat.DataSet;
import jsat.SimpleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.WarmClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.Uniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.lossfunctions.LogisticLoss;
import jsat.lossfunctions.LossC;
import jsat.lossfunctions.LossFunc;
import jsat.lossfunctions.LossMC;
import jsat.lossfunctions.LossR;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.random.RandomUtil;

public class SDCA
implements Classifier,
Regressor,
Parameterized,
SimpleWeightVectorModel,
WarmClassifier,
WarmRegressor {
    private LossFunc loss;
    private boolean useBias = true;
    private double tol = 0.001;
    private double lambda;
    private double alpha = 0.5;
    private int max_epochs = 200;
    private double[] dual_alphas;
    protected int epochs_taken;
    private Vec[] ws;
    private double[] bs;

    public SDCA() {
        this(1.0E-5);
    }

    public SDCA(double lambda) {
        this(lambda, new LogisticLoss());
    }

    public SDCA(double lambda, LossFunc loss) {
        this.setLoss(loss);
        this.setLambda(lambda);
    }

    public SDCA(SDCA toCopy) {
        this.loss = toCopy.loss.clone();
        this.useBias = toCopy.useBias;
        this.tol = toCopy.tol;
        this.lambda = toCopy.lambda;
        this.alpha = toCopy.alpha;
        this.max_epochs = toCopy.max_epochs;
        this.epochs_taken = toCopy.epochs_taken;
        if (toCopy.dual_alphas != null) {
            this.dual_alphas = Arrays.copyOf(toCopy.dual_alphas, toCopy.dual_alphas.length);
        }
        if (toCopy.ws != null) {
            this.ws = new Vec[toCopy.ws.length];
            this.bs = new double[toCopy.bs.length];
            for (int i = 0; i < toCopy.ws.length; ++i) {
                this.ws[i] = toCopy.ws[i].clone();
                this.bs[i] = toCopy.bs[i];
            }
        }
    }

    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    public boolean isUseBias() {
        return this.useBias;
    }

    @Parameter.WarmParameter(prefLowToHigh=false)
    public void setLambda(double lambda) {
        if (lambda <= 0.0 || Double.isInfinite(lambda) || Double.isNaN(lambda)) {
            throw new IllegalArgumentException("Regularization term lambda must be a positive value, not " + lambda);
        }
        this.lambda = lambda;
    }

    public double getLambda() {
        return this.lambda;
    }

    public void setAlpha(double alpha) {
        if (alpha < 0.0 || alpha > 1.0 || Double.isNaN(alpha)) {
            throw new IllegalArgumentException("alpha must be in [0, 1], not " + alpha);
        }
        this.alpha = alpha;
    }

    public double getAlpha() {
        return this.alpha;
    }

    public void setMaxIters(int maxOuterIters) {
        if (maxOuterIters < 1) {
            throw new IllegalArgumentException("Number of training iterations must be positive, not " + maxOuterIters);
        }
        this.max_epochs = maxOuterIters;
    }

    public int getMaxIters() {
        return this.max_epochs;
    }

    public void setTolerance(double e_out) {
        if (e_out <= 0.0 || Double.isNaN(e_out)) {
            throw new IllegalArgumentException("convergence tolerance paramter must be positive, not " + e_out);
        }
        this.tol = e_out;
    }

    public double getTolerance() {
        return this.tol;
    }

    public void setLoss(LossFunc loss) {
        this.loss = loss;
    }

    public LossFunc getLoss() {
        return this.loss;
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        Vec x = data.getNumericalValues();
        if (this.ws.length == 1) {
            return ((LossC)this.loss).getClassification(this.ws[0].dot(x) + this.bs[0]);
        }
        DenseVector pred = new DenseVector(this.ws.length);
        for (int i = 0; i < this.ws.length; ++i) {
            ((Vec)pred).set(i, this.ws[i].dot(x) + this.bs[i]);
        }
        ((LossMC)this.loss).process(pred, pred);
        return ((LossMC)this.loss).getClassification(pred);
    }

    @Override
    public double regress(DataPoint data) {
        Vec x = data.getNumericalValues();
        return ((LossR)this.loss).getRegression(this.ws[0].dot(x) + this.bs[0]);
    }

    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    @Override
    public void train(ClassificationDataSet dataSet) {
        if (dataSet.getPredicting().getNumOfCategories() != 2) {
            throw new RuntimeException("Current SDCA implementation only support binary classification problems");
        }
        double[] targets = new double[dataSet.size()];
        for (int i = 0; i < targets.length; ++i) {
            targets[i] = dataSet.getDataPointCategory(i) * 2 - 1;
        }
        this.trainProxSDCA(dataSet, targets, null);
    }

    @Override
    public void train(ClassificationDataSet dataSet, Classifier warmSolution, boolean parallel) {
        this.train(dataSet, warmSolution);
    }

    @Override
    public void train(ClassificationDataSet dataSet, Classifier warmSolution) {
        if (warmSolution == null || !(warmSolution instanceof SDCA)) {
            throw new FailedToFitException("SDCA implementation can only be warm-started from another instance of SDCA");
        }
        if (dataSet.getPredicting().getNumOfCategories() != 2) {
            throw new RuntimeException("Current SDCA implementation only support binary classification problems");
        }
        double[] targets = new double[dataSet.size()];
        for (int i = 0; i < targets.length; ++i) {
            targets[i] = dataSet.getDataPointCategory(i) * 2 - 1;
        }
        this.trainProxSDCA(dataSet, targets, ((SDCA)warmSolution).dual_alphas);
    }

    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    @Override
    public void train(RegressionDataSet dataSet) {
        double[] targets = new double[dataSet.size()];
        for (int i = 0; i < targets.length; ++i) {
            targets[i] = dataSet.getTargetValue(i);
        }
        this.trainProxSDCA(dataSet, targets, null);
    }

    @Override
    public void train(RegressionDataSet dataSet, Regressor warmSolution, boolean parallel) {
        this.train(dataSet, warmSolution);
    }

    @Override
    public void train(RegressionDataSet dataSet, Regressor warmSolution) {
        double[] targets = new double[dataSet.size()];
        for (int i = 0; i < targets.length; ++i) {
            targets[i] = dataSet.getTargetValue(i);
        }
        this.trainProxSDCA(dataSet, targets, ((SDCA)warmSolution).dual_alphas);
    }

    private void trainProxSDCA(DataSet dataSet, double[] targets, double[] warm_alphas) {
        DenseVector w_lazy;
        double[] w_lazy_backing;
        double tol_effective;
        double lambda_effective;
        double sigma_p;
        int i;
        int i2;
        int N = dataSet.size();
        int D2 = dataSet.getNumNumericalVars();
        this.ws = new Vec[]{new DenseVector(D2)};
        DenseVector v = new DenseVector(D2);
        this.bs = new double[1];
        double[] x_norms = new double[N];
        double scaling = 1.0;
        boolean is_regression = dataSet instanceof RegressionDataSet;
        for (i2 = 0; i2 < N; ++i2) {
            x_norms[i2] = dataSet.getDataPoint(i2).getNumericalValues().pNorm(2.0);
            if (is_regression) continue;
            scaling = Math.max(scaling, x_norms[i2]);
        }
        i2 = 0;
        while (i2 < N) {
            int n = i2++;
            x_norms[n] = x_norms[n] / scaling;
        }
        if (this.alpha == 1.0) {
            double y_bar = 0.0;
            for (i = 0; i < N; ++i) {
                y_bar += this.loss.getLoss(0.0, targets[i]);
            }
            sigma_p = this.lambda;
            lambda_effective = this.tol * Math.pow(this.lambda / Math.max(y_bar /= (double)N, 1.0E-7), 2.0);
            tol_effective = this.tol / 2.0;
        } else {
            lambda_effective = this.lambda;
            sigma_p = this.alpha / (1.0 - this.alpha);
            tol_effective = this.tol;
        }
        if (this.alpha > 0.0) {
            w_lazy_backing = new double[D2];
            w_lazy = new DenseVector(w_lazy_backing);
        } else {
            w_lazy_backing = null;
            w_lazy = v;
        }
        if (warm_alphas == null) {
            this.dual_alphas = new double[N];
        } else {
            if (N != warm_alphas.length) {
                throw new FailedToFitException("SDCA only supports warm-start training from the same dataset. A dataset of side " + N + " was given for training, but the warm solution was trained on " + warm_alphas.length + " points.");
            }
            this.dual_alphas = Arrays.copyOf(warm_alphas, warm_alphas.length);
            for (i = 0; i < N; ++i) {
                v.mutableAdd(this.dual_alphas[i], dataSet.getDataPoint(i).getNumericalValues());
                if (!this.useBias) continue;
                this.bs[0] = this.bs[0] + this.dual_alphas[i];
            }
            v.mutableDivide(scaling * lambda_effective * (double)N);
            this.bs[0] = this.bs[0] / (scaling * lambda_effective * (double)N);
        }
        Random rand = RandomUtil.getRandom();
        double gamma = this.loss.lipschitz();
        IntList epoch_order = new IntList(N);
        ListUtils.addRange(epoch_order, 0, N, 1);
        this.epochs_taken = 0;
        int primal_converg_check = 0;
        for (int epoch = 0; epoch < this.max_epochs; ++epoch) {
            double prevPrimal = Double.POSITIVE_INFINITY;
            ++this.epochs_taken;
            double dual_loss_est = 0.0;
            double primal_loss_est = 0.0;
            Collections.shuffle(epoch_order, rand);
            Iterator iterator = epoch_order.iterator();
            while (iterator.hasNext()) {
                double raw_score;
                double lossD;
                double u;
                double q;
                double q_sqrd;
                int i3 = (Integer)iterator.next();
                double alpha_i_prev = this.dual_alphas[i3];
                Vec x = dataSet.getDataPoint(i3).getNumericalValues();
                double y = targets[i3];
                if (this.alpha > 0.0) {
                    for (IndexValue iv : x) {
                        int j = iv.getIndex();
                        double v_j = v.get(j);
                        double v_j_sign = Math.signum(v_j);
                        double v_j_abs = Math.abs(v_j);
                        w_lazy_backing[j] = v_j_sign * Math.max(v_j_abs - sigma_p, 0.0);
                    }
                }
                if ((q_sqrd = (q = (u = -(lossD = this.loss.getDeriv(raw_score = w_lazy.dot(x) / scaling + this.bs[0], y))) - alpha_i_prev) * q) <= 1.0E-32) continue;
                double phi_i = this.loss.getLoss(raw_score, y);
                double conjg = this.loss.getConjugate(-alpha_i_prev, raw_score, y);
                double x_norm = x_norms[i3];
                double x_norm_sqrd = x_norm * x_norm;
                double denom = q_sqrd * (gamma + x_norm_sqrd / (lambda_effective * (double)N));
                double s = (phi_i + conjg + raw_score * alpha_i_prev + gamma * q_sqrd / 2.0) / denom;
                s = Math.min(1.0, s);
                primal_loss_est += phi_i;
                if (!Double.isInfinite(conjg)) {
                    dual_loss_est += -conjg;
                }
                if (s == 0.0) continue;
                double alpha_i_delta = s * q;
                int n = i3;
                this.dual_alphas[n] = this.dual_alphas[n] + alpha_i_delta;
                v.mutableAdd(alpha_i_delta / (scaling * lambda_effective * (double)N), x);
                if (!this.useBias) continue;
                this.bs[0] = this.bs[0] + alpha_i_delta / (scaling * lambda_effective * (double)N);
            }
            double gap = Math.abs(primal_loss_est - dual_loss_est) / (double)N;
            if (gap < tol_effective) break;
            if (prevPrimal - primal_loss_est / (double)N < tol_effective / 5.0) {
                if (primal_converg_check++ > 10) {
                    break;
                }
            } else {
                primal_converg_check = 0;
            }
            prevPrimal = primal_loss_est / (double)N;
        }
        for (int j = 0; j < D2; ++j) {
            double v_j = v.get(j);
            double v_j_sign = Math.signum(v_j);
            double v_j_abs = Math.abs(v_j);
            this.ws[0].set(j, v_j_sign * Math.max(v_j_abs - sigma_p, 0.0) / scaling);
        }
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public SDCA clone() {
        return new SDCA(this);
    }

    @Override
    public Vec getRawWeight(int index) {
        return this.ws[index];
    }

    @Override
    public double getBias(int index) {
        return this.bs[index];
    }

    @Override
    public int numWeightsVecs() {
        return this.ws.length;
    }

    @Override
    public boolean warmFromSameDataOnly() {
        return true;
    }

    public static Distribution guessLambda(DataSet d) {
        int N = d.size();
        return new LogUniform(1.0 / (double)(N * 50), Math.min(1.0 / (double)(N / 50), 1.0));
    }

    public static Distribution guessAlpha(DataSet d) {
        return new Uniform(0.0, 0.5);
    }
}

