/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.svm;

import java.util.Arrays;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.stream.IntStream;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.WarmClassifier;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.svm.PlattSMO;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.distributions.Distribution;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.LinearKernel;
import jsat.exceptions.FailedToFitException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Least-squares support vector machine usable for binary classification and for regression.
 * <p>
 * Training keeps a cached value {@code F_i} for every training point. It repeatedly updates the
 * pair of dual variables at the current maximum ({@code i_up}) and minimum ({@code i_low}) of that
 * cache, holding their sum fixed. It stops once the duality gap is no larger than {@code TOL}
 * times the dual objective, or when a step makes no progress.
 * <p>
 * For classification the two classes are mapped to targets {@code -1} and {@code +1}. Class 1 is
 * predicted when the regression score is positive.
 */
public class LSSVM
extends SupportVectorLearner
implements BinaryScoreClassifier,
Regressor,
Parameterized,
WarmRegressor,
WarmClassifier {
    private static final long serialVersionUID = -7569924400631719451L;
    /** Bias subtracted from the kernel sum in {@link #regress(DataPoint)}. */
    protected double b = 0.0;
    /** Minimum of {@code fcache} found by the most recent scan. */
    protected double b_low;
    /** Maximum of {@code fcache} found by the most recent scan. */
    protected double b_up;
    /** Regularization parameter; always in (0, Infinity). */
    private double C = 1.0;
    /** Index of the maximum of {@code fcache} (paired with {@link #b_up}). */
    private int i_up;
    /** Index of the minimum of {@code fcache} (paired with {@link #b_low}). */
    private int i_low;
    /**
     * Cached per-training-point values F_i. They start at {@code -target_i} for a cold start and
     * are updated incrementally by each alpha change times the corresponding kernel value.
     */
    private double[] fcache;
    /** Running value of the dual objective, used to scale the stopping tolerance. */
    private double dualObjective;
    /** Relative threshold below which a change in an alpha is treated as no progress. */
    private static final double EPSILON = 1.0E-12;
    /** Stop when the duality gap is no more than this fraction of the dual objective. */
    private static final double TOL = 0.001;

    /**
     * Creates an LS-SVM that uses a {@link LinearKernel} and no kernel cache.
     */
    public LSSVM() {
        this(new LinearKernel());
    }

    /**
     * Creates an LS-SVM with the given kernel and no kernel cache.
     *
     * @param kernel the kernel to use
     */
    public LSSVM(KernelTrick kernel) {
        this(kernel, SupportVectorLearner.CacheMode.NONE);
    }

    /**
     * Creates an LS-SVM with the given kernel and kernel-cache mode.
     *
     * @param kernel the kernel to use
     * @param cacheMode the kernel cache mode to use while training
     */
    public LSSVM(KernelTrick kernel, SupportVectorLearner.CacheMode cacheMode) {
        super(kernel, cacheMode);
    }

    /**
     * Copy constructor.
     *
     * @param toCopy the model to copy
     */
    public LSSVM(LSSVM toCopy) {
        super(toCopy.getKernel().clone(), toCopy.getCacheMode());
        // NOTE(review): the training vectors (vecs) held by the superclass are not carried over by
        // this super(...) call. Confirm whether a trained clone is expected to be able to predict.
        this.b = toCopy.b;
        this.b_low = toCopy.b_low;
        this.b_up = toCopy.b_up;
        this.i_up = toCopy.i_up;
        this.i_low = toCopy.i_low;
        this.C = toCopy.C;
        this.dualObjective = toCopy.dualObjective;
        if (toCopy.alphas != null) {
            this.alphas = Arrays.copyOf(toCopy.alphas, toCopy.alphas.length);
        }
        if (toCopy.fcache != null) {
            this.fcache = Arrays.copyOf(toCopy.fcache, toCopy.fcache.length);
        }
    }

    /**
     * Sets the regularization parameter.
     *
     * @param C2 the new value, which must be positive and finite
     * @throws IllegalArgumentException if {@code C2} is not in (0, Infinity)
     */
    @Parameter.WarmParameter(prefLowToHigh=true)
    public void setC(double C2) {
        if (C2 <= 0.0 || Double.isNaN(C2) || Double.isInfinite(C2)) {
            throw new IllegalArgumentException("C must be in (0, Infty), not " + C2);
        }
        this.C = C2;
    }

    /**
     * Returns the regularization parameter.
     *
     * @return the regularization parameter
     */
    public double getC() {
        return this.C;
    }

    /**
     * Jointly updates {@code alphas[i1]} and {@code alphas[i2]}, keeping their sum fixed. It then
     * refreshes {@code fcache} and recomputes {@code b_up}/{@code i_up} and {@code b_low}/{@code i_low}.
     *
     * @return {@code true} if the alphas were changed, {@code false} if no progress was possible
     */
    private boolean takeStep(int i1, int i2, ExecutorService ex, boolean parallel) throws InterruptedException, ExecutionException {
        if (i1 == i2) {
            // A pair with itself has eta == 0 and F1 - F2 == 0. The step would be 0/0 = NaN and
            // would poison the alphas, so report that no progress is possible.
            return false;
        }
        final double alph1 = this.alphas[i1];
        final double alph2 = this.alphas[i2];
        final double F1 = this.fcache[i1];
        final double F2 = this.fcache[i2];
        final double gamma = alph1 + alph2;
        final double k11 = this.kEval(i1, i1);
        final double k12 = this.kEval(i2, i1);
        final double k22 = this.kEval(i2, i2);
        // NOTE(review): the LS-SVM dual uses K + I/C. No 1/C diagonal term is added here or in the
        // fcache update below. Confirm that kEval supplies it.
        final double eta = 2.0 * k12 - k11 - k22;
        final double t = (F1 - F2) / eta;
        final double a2 = alph2 - t;
        // NOTE(review): LS-SVM alphas may be negative. When a2 + alph2 < 0 the right-hand side is
        // negative and this "no progress" guard can never fire.
        if (Math.abs(a2 - alph2) < EPSILON * (a2 + alph2 + EPSILON)) {
            return false;
        }
        final double a1 = gamma - a2;
        this.alphas[i1] = a1;
        this.alphas[i2] = a2;
        this.dualObjective -= eta / 2.0 * t * t;

        final double delta1 = a1 - alph1;
        final double delta2 = a2 - alph2;
        this.b_up = Double.NEGATIVE_INFINITY;
        this.b_low = Double.POSITIVE_INFINITY;
        ParallelUtils.run(parallel, this.fcache.length, (from, to) -> {
            int iUpChunk = from;
            int iLowChunk = from;
            double bUpChunk = Double.NEGATIVE_INFINITY;
            double bLowChunk = Double.POSITIVE_INFINITY;
            for (int i = from; i < to; ++i) {
                double k_i1 = this.kEval(i1, i);
                double k_i2 = this.kEval(i2, i);
                this.fcache[i] += delta1 * k_i1 + delta2 * k_i2;
                double Fi = this.fcache[i];
                if (Fi > bUpChunk) {
                    bUpChunk = Fi;
                    iUpChunk = i;
                }
                if (Fi < bLowChunk) {
                    bLowChunk = Fi;
                    iLowChunk = i;
                }
            }
            // Merge this chunk's extrema into the shared ones. The tracked values are compared
            // rather than re-reading fcache[iUpChunk], so an empty chunk (from == to) never
            // indexes the array.
            synchronized (this.fcache) {
                if (bUpChunk > this.b_up) {
                    this.b_up = bUpChunk;
                    this.i_up = iUpChunk;
                }
                if (bLowChunk < this.b_low) {
                    this.b_low = bLowChunk;
                    this.i_low = iLowChunk;
                }
            }
        }, ex);
        return true;
    }

    @Override
    public boolean warmFromSameDataOnly() {
        return true;
    }

    /**
     * Updates the bias {@link #b} and returns the current duality gap.
     *
     * @param fast if {@code true}, {@code b} is the midpoint of {@code b_up} and {@code b_low};
     *             otherwise it is the mean of {@code fcache[i] - alphas[i] / C}
     * @param parallel whether the reductions may run in parallel
     * @return the duality gap summed over all training points
     */
    private double computeDualityGap(boolean fast, boolean parallel) throws InterruptedException, ExecutionException {
        if (fast) {
            this.b = (this.b_up + this.b_low) / 2.0;
        } else {
            this.b = ParallelUtils.streamP(IntStream.range(0, this.alphas.length), parallel).mapToDouble(i -> this.fcache[i] - this.alphas[i] / this.C).sum();
            this.b /= (double)this.alphas.length;
        }
        return ParallelUtils.streamP(IntStream.range(0, this.alphas.length), parallel).mapToDouble(i -> {
            double x_i = this.b + this.alphas[i] / this.C - this.fcache[i];
            return this.alphas[i] * (this.fcache[i] - 0.5 * this.alphas[i] / this.C) + this.C * x_i * x_i / 2.0;
        }).sum();
    }

    /**
     * Allocates and seeds {@code alphas}, {@code fcache} and {@code dualObjective}. Seeding is from
     * {@code warmSolution} if given, otherwise a cold start. It then locates the initial
     * {@code b_up}/{@code i_up} and {@code b_low}/{@code i_low}.
     *
     * @param targets the numeric target for every training point
     * @param warmSolution a model previously trained on the same data, or {@code null}
     * @throws FailedToFitException if the warm solution is untrained or has a different size
     */
    private void initializeVariables(double[] targets, LSSVM warmSolution) {
        // Capture the warm state before reallocating. warmSolution may be this very object, and
        // reading its arrays after the reallocation below would see only zeros.
        double[] warmAlphas = null;
        double[] warmFcache = null;
        double warmC = 0.0;
        if (warmSolution != null) {
            if (warmSolution.alphas == null || warmSolution.fcache == null) {
                throw new FailedToFitException("Warm LS-SVM solution has not been trained");
            }
            if (warmSolution.alphas.length != targets.length) {
                throw new FailedToFitException("Warm LS-SVM solution could not have been trained on the same data, different number of alpha values present");
            }
            warmAlphas = warmSolution.alphas;
            warmFcache = warmSolution.fcache;
            warmC = warmSolution.C;
        }
        this.alphas = new double[targets.length];
        this.fcache = new double[targets.length];
        this.dualObjective = 0.0;
        if (warmAlphas != null) {
            // Adjust each cached value for the change from the warm model's C to this model's C.
            double C_ratio = this.C / warmC;
            for (int i = 0; i < targets.length; ++i) {
                this.alphas[i] = warmAlphas[i];
                this.fcache[i] = warmFcache[i] - (C_ratio - 1.0) * warmAlphas[i] / this.C;
                this.dualObjective += this.alphas[i] * (targets[i] - this.fcache[i]);
            }
            this.dualObjective /= 2.0;
        } else {
            for (int i = 0; i < targets.length; ++i) {
                this.fcache[i] = -targets[i];
            }
        }
        this.b_up = Double.NEGATIVE_INFINITY;
        this.b_low = Double.POSITIVE_INFINITY;
        for (int i = 0; i < this.fcache.length; ++i) {
            double Fi = this.fcache[i];
            if (Fi > this.b_up) {
                this.b_up = Fi;
                this.i_up = i;
            }
            if (Fi < this.b_low) {
                this.b_low = Fi;
                this.i_low = i;
            }
        }
        // Re-apply the current cache mode, presumably so the kernel cache is rebuilt for the newly
        // assigned training vectors. TODO confirm against SupportVectorLearner.
        this.setCacheMode(this.getCacheMode());
    }

    /**
     * Returns the raw regression score, whose sign determines the predicted class.
     *
     * @param dp the point to score
     * @return the same value as {@link #regress(DataPoint)}
     */
    @Override
    public double getScore(DataPoint dp) {
        return this.regress(dp);
    }

    /**
     * Classifies a point as class 1 when its score is positive, otherwise as class 0.
     *
     * @param data the point to classify
     * @return a two-class result with probability 1 on the predicted class
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        if (this.regress(data) > 0.0) {
            cr.setProb(1, 1.0);
        } else {
            cr.setProb(0, 1.0);
        }
        return cr;
    }

    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.train(dataSet, null, parallel);
    }

    /**
     * Trains as a regressor, optionally warm-starting from a previous LS-SVM.
     *
     * @throws FailedToFitException if {@code warmSolution} is non-null and not an {@code LSSVM}
     */
    @Override
    public void train(RegressionDataSet dataSet, Regressor warmSolution, boolean parallel) {
        if (warmSolution != null && !(warmSolution instanceof LSSVM)) {
            throw new FailedToFitException("Warm solution must be an implementation of LS-SVM, not " + warmSolution.getClass());
        }
        double[] targets = dataSet.getTargetValues().arrayCopy();
        this.mainLoop(dataSet, (LSSVM)warmSolution, targets, parallel);
    }

    @Override
    public void train(RegressionDataSet dataSet, Regressor warmSolution) {
        this.train(dataSet, warmSolution, false);
    }

    /**
     * Trains as a binary classifier, mapping category 0 to target -1 and category 1 to +1.
     *
     * @throws FailedToFitException if the data set does not have exactly two classes, or if
     *         {@code warmSolution} is non-null and not an {@code LSSVM}
     */
    @Override
    public void train(ClassificationDataSet dataSet, Classifier warmSolution, boolean parallel) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("LS-SVM only supports binary classification problems");
        }
        if (warmSolution != null && !(warmSolution instanceof LSSVM)) {
            throw new FailedToFitException("Warm solution must be an implementation of LS-SVM, not " + warmSolution.getClass());
        }
        double[] targets = new double[dataSet.size()];
        for (int i = 0; i < dataSet.size(); ++i) {
            targets[i] = dataSet.getDataPointCategory(i) * 2 - 1;
        }
        this.mainLoop(dataSet, (LSSVM)warmSolution, targets, parallel);
    }

    @Override
    public void train(ClassificationDataSet dataSet, Classifier warmSolution) {
        this.train(dataSet, warmSolution, false);
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns the kernel-weighted sum over the support vectors minus the bias {@link #b}.
     *
     * @param data the point to evaluate
     * @return the regression output for {@code data}
     */
    @Override
    public double regress(DataPoint data) {
        return this.kEvalSum(data.getNumericalValues()) - this.b;
    }

    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.train(dataSet, null, parallel);
    }

    @Override
    public LSSVM clone() {
        return new LSSVM(this);
    }

    /**
     * Shared training driver for classification and regression.
     *
     * @param dataSet the training data, whose vectors become this model's vectors
     * @param warmSolution a model to warm-start from, or {@code null}
     * @param targets the numeric target for every training point
     * @param parallel whether to use multiple threads
     * @throws FailedToFitException if training is interrupted or a parallel task fails
     */
    private void mainLoop(DataSet dataSet, LSSVM warmSolution, double[] targets, boolean parallel) {
        // The executor is obtained for this fit alone, so it is released in the finally block
        // whether training succeeds or fails.
        ExecutorService ex = ParallelUtils.getNewExecutor(parallel);
        try {
            this.vecs = dataSet.getDataVectors();
            this.initializeVariables(targets, warmSolution);
            boolean change = true;
            double dualityGap = this.computeDualityGap(true, parallel);
            while (dualityGap > TOL * this.dualObjective && change) {
                change = this.takeStep(this.i_up, this.i_low, ex, parallel);
                dualityGap = this.computeDualityGap(true, parallel);
            }
            this.setCacheMode(null);
            this.setAlphas(this.alphas);
        }
        catch (InterruptedException interruptedException) {
            // Restore the interrupt status so callers further up can still observe it.
            Thread.currentThread().interrupt();
            throw new FailedToFitException(interruptedException);
        }
        catch (ExecutionException executionException) {
            throw new FailedToFitException(executionException);
        }
        finally {
            ex.shutdownNow();
        }
    }

    /**
     * Returns a distribution of plausible values for C, delegating to {@link PlattSMO#guessC}.
     *
     * @param d the data set to base the guess on
     * @return the distribution returned by {@code PlattSMO.guessC(d)}
     */
    public static Distribution guessC(DataSet d) {
        return PlattSMO.guessC(d);
    }
}

