/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.svm;

import java.util.Arrays;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.NormalizedKernel;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.utils.concurrent.AtomicDouble;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Binary kernel SVM trained without a bias (offset) term, using a greedy
 * one-coordinate-at-a-time ascent solver on the dual variables.
 * <p>
 * Class indices 0 and 1 are mapped to labels -1 and +1. A point is assigned
 * class 1 when its score ({@link #getScore(DataPoint)}) is strictly positive,
 * and class 0 otherwise. Each dual variable alpha_i is kept in the box
 * [0, w_i * C], where w_i is the data point's weight.
 * <p>
 * NOTE(review): the solver helper names ("procedure 3", "procedure 4m")
 * presumably refer to Steinwart, Hush &amp; Scovel, "Training SVMs Without
 * Offset" (JMLR 2011) — confirm against the original source.
 */
public class SVMnoBias
extends SupportVectorLearner
implements BinaryScoreClassifier {
    /** Regularization constant; always positive and finite. */
    private double C = 1.0;
    /** Convergence tolerance; larger values stop training earlier. */
    private double tolerance = 0.001;
    /** Per-point labels in {-1, +1}, derived from class indices 0 and 1. */
    protected short[] label;
    /** Per-point data weights; alpha_i is boxed to [0, weights[i] * C]. */
    protected Vec weights;
    /** Solver bookkeeping term T(a) = -sum_i alpha_i * nabla_W[i], updated incrementally. */
    private double T_a;
    /** Stopping quantity S(a) = T(a) + E(a); training stops once it drops to the tolerance bound. */
    private double S_a;

    /**
     * Creates a new learner with the given kernel and kernel caching disabled.
     *
     * @param kf the kernel to use
     */
    public SVMnoBias(KernelTrick kf) {
        super(kf, SupportVectorLearner.CacheMode.NONE);
    }

    /**
     * Copy constructor. Weights and labels are deep-copied; the solver's
     * scratch values (T(a), S(a)) are not, since they only matter during training.
     *
     * @param toCopy the object to copy
     */
    public SVMnoBias(SVMnoBias toCopy) {
        super(toCopy);
        if (toCopy.weights != null) {
            this.weights = toCopy.weights.clone();
        }
        if (toCopy.label != null) {
            this.label = Arrays.copyOf(toCopy.label, toCopy.label.length);
        }
        this.C = toCopy.C;
        this.tolerance = toCopy.tolerance;
    }

    /**
     * Sets the kernel, wrapping it in a {@link NormalizedKernel} if it is not
     * already normalized. The incremental T(a) update in the solver relies on
     * k(x, x) == 1, which normalization provides.
     *
     * @param kernel the kernel to use
     */
    @Override
    public void setKernel(KernelTrick kernel) {
        if (kernel.normalized()) {
            super.setKernel(kernel);
        } else {
            super.setKernel(new NormalizedKernel(kernel));
        }
    }

    /**
     * Returns the signed decision value for the given point; a strictly
     * positive value corresponds to class 1.
     *
     * @param dp the point to score
     * @return the kernel expansion evaluated at {@code dp}
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public double getScore(DataPoint dp) {
        this.requireTrained();
        return this.kEvalSum(dp.getNumericalValues());
    }

    @Override
    public SVMnoBias clone() {
        return new SVMnoBias(this);
    }

    /**
     * Classifies a point with a hard 0/1 decision based on the sign of its score.
     *
     * @param data the point to classify
     * @return results with probability 1 on class 1 if the score is &gt; 0, else on class 0
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        this.requireTrained();
        CategoricalResults cr = new CategoricalResults(2);
        double score = this.getScore(data);
        // NaN scores fall to class 0, as "score > 0" is false for NaN.
        cr.setProb(score > 0.0 ? 1 : 0, 1.0);
        return cr;
    }

    /** Throws if no model has been trained yet. */
    private void requireTrained() {
        if (this.vecs == null) {
            throw new UntrainedModelException("Classifier has yet to be trained");
        }
    }

    /**
     * Trains from scratch, starting with all dual variables at zero.
     *
     * @param dataSet binary data set whose categories are 0 or 1
     * @param parallel whether the per-iteration gradient update may run in parallel
     * @throws IllegalArgumentException if a data point's category is not 0 or 1
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.bookKeepingInit(dataSet);
        double[] nabla_W = this.procedure3_init();
        this.solver_1d(nabla_W, parallel);
        this.setCacheMode(null);
    }

    /**
     * Trains serially, starting from the given dual values.
     *
     * @param dataSet binary data set whose categories are 0 or 1
     * @param warm_start initial dual values, one per data point (signs are ignored)
     */
    protected void train(ClassificationDataSet dataSet, double[] warm_start) {
        this.train(dataSet, warm_start, false);
    }

    /**
     * Trains starting from the given dual values. The values may carry the
     * label sign (as produced by a previous training run); their magnitudes
     * are clipped into the feasible box [0, w_i * C] before solving.
     *
     * @param dataSet binary data set whose categories are 0 or 1
     * @param warm_start initial dual values, at least one per data point
     * @param parallel whether gradient computations may run in parallel
     * @throws IllegalArgumentException if {@code warm_start} is null or too short,
     *         or if a data point's category is not 0 or 1
     */
    protected void train(ClassificationDataSet dataSet, double[] warm_start, boolean parallel) {
        if (warm_start == null) {
            throw new IllegalArgumentException("warm_start must not be null");
        }
        if (warm_start.length < dataSet.size()) {
            throw new IllegalArgumentException("warm_start has " + warm_start.length
                    + " values but the data set has " + dataSet.size() + " points");
        }
        this.bookKeepingInit(dataSet);
        for (int i = 0; i < this.alphas.length; ++i) {
            // solver_1d stores signed alphas, so strip the sign; clip so the solver starts feasible.
            this.alphas[i] = Math.min(Math.abs(warm_start[i]), this.weights.get(i) * this.C);
        }
        double[] nabla_W = this.procedure4m_init(parallel);
        this.solver_1d(nabla_W, parallel);
        this.setCacheMode(null);
    }

    /**
     * Greedy coordinate-ascent solver. Each iteration updates the single dual
     * variable with the largest gain, then refreshes the gradient, T(a) and
     * S(a). It stops when S(a) drops to the tolerance bound or no coordinate
     * can make progress. On exit, every alpha_i is multiplied by its label.
     *
     * @param nabla_W gradient vector, nabla_W[i] = 1 - y_i * sum_j alpha_j y_j K_ij;
     *        updated in place
     * @param parallel whether the gradient update may run in parallel
     */
    private void solver_1d(double[] nabla_W, boolean parallel) {
        final int n = this.alphas.length;
        final double lambda = 1.0 / (2.0 * this.C * (double) n);
        // Loop-invariant stopping bound (mathematically tolerance * C * n).
        final double stopBound = this.tolerance / (2.0 * lambda);
        while (this.S_a > stopBound) {
            // Find the coordinate whose box-clipped step alpha_i + nabla_W[i]
            // gives the largest gain; ties go to the later index.
            double bestGain = -1.0;
            int iMax = -1;
            double bestDelta = -1.0;
            for (int i = 0; i < n; ++i) {
                double aStar = Math.max(Math.min(this.weights.get(i) * this.C, nabla_W[i] + this.alphas[i]), 0.0);
                double step = aStar - this.alphas[i];
                double gain = step * (nabla_W[i] - step / 2.0);
                if (gain >= bestGain) {
                    bestGain = gain;
                    iMax = i;
                    bestDelta = step;
                }
            }
            // No usable coordinate (all gains NaN) or no possible progress:
            // further iterations would not change the state, so stop instead of spinning.
            if (iMax < 0 || bestGain <= 0.0) {
                break;
            }

            final int chosen = iMax;
            final double before = this.alphas[chosen];
            final double upper = this.weights.get(chosen) * this.C;
            double after = before + bestDelta;
            // Snap values within 1e-7 of a box bound exactly onto that bound.
            if (after + 1.0E-7 > upper) {
                after = upper;
            } else if (after - 1.0E-7 < 0.0) {
                after = 0.0;
            }
            this.alphas[chosen] = after;
            // Use the step actually taken (after snapping) so T(a) and the
            // gradient stay consistent with alphas.
            final double delta = after - before;
            this.T_a -= delta * (2.0 * nabla_W[chosen] - 1.0 - delta);

            AtomicDouble E_a = new AtomicDouble(0.0);
            this.accessingRow(chosen);
            ParallelUtils.run(parallel, n, (start, end) -> {
                double localE = 0.0;
                for (int j = start; j < end; ++j) {
                    nabla_W[j] -= delta * (double) this.label[chosen] * (double) this.label[j] * this.kEval(chosen, j);
                    localE += this.weights.get(j) * this.C * Math.min(Math.max(0.0, nabla_W[j]), 2.0);
                }
                E_a.addAndGet(localE);
            });
            this.S_a = this.T_a + E_a.get();
        }
        this.accessingRow(-1);
        // Fold the labels into the stored coefficients.
        for (int i = 0; i < this.label.length; ++i) {
            this.alphas[i] *= this.label[i];
        }
    }

    /**
     * Initializes solver state for a cold start (all alphas zero): the gradient
     * is all ones, T(a) = 0 and S(a) = sum_i w_i * C.
     *
     * @return the initial gradient vector
     */
    private double[] procedure3_init() {
        int n = this.alphas.length;
        this.T_a = 0.0;
        this.S_a = 0.0;
        double[] nabla_W = new double[n];
        for (int i = 0; i < n; ++i) {
            nabla_W[i] = 1.0;
            this.S_a += this.weights.get(i) * this.C;
        }
        return nabla_W;
    }

    /**
     * Initializes solver state from the current (warm-start) alphas by computing
     * the full gradient, T(a) = -sum_i alpha_i * nabla_W[i] and
     * S(a) = T(a) + sum_i w_i * C * clip(nabla_W[i], 0, 2).
     *
     * @param parallel whether rows may be processed in parallel
     * @return the gradient vector for the current alphas
     */
    private double[] procedure4m_init(boolean parallel) {
        int N = this.alphas.length;
        AtomicDouble E_a = new AtomicDouble(0.0);
        AtomicDouble T_a_accum = new AtomicDouble(0.0);
        double[] nabla_W = new double[N];
        ParallelUtils.run(parallel, N, (start, end) -> {
            double Ta_delta = 0.0;
            double Ea_delta = 0.0;
            for (int i = start; i < end; ++i) {
                double nabla_Wi_delta = 0.0;
                for (int j = 0; j < N; ++j) {
                    if (this.alphas[j] == 0.0) {
                        continue; // zero alphas contribute nothing
                    }
                    double k_ij = this.getCacheMode() == SupportVectorLearner.CacheMode.FULL ? this.kEval(i, j) : this.k(i, j);
                    nabla_Wi_delta -= this.alphas[j] * (double) this.label[i] * (double) this.label[j] * k_ij;
                }
                nabla_W[i] = 1.0 + nabla_Wi_delta;
                Ta_delta -= this.alphas[i] * nabla_W[i];
                Ea_delta += this.weights.get(i) * this.C * Math.min(Math.max(nabla_W[i], 0.0), 2.0);
            }
            E_a.addAndGet(Ea_delta);
            T_a_accum.addAndGet(Ta_delta);
        });
        this.T_a = T_a_accum.get();
        this.S_a = this.T_a + E_a.get();
        return nabla_W;
    }

    /**
     * Loads vectors, weights and +/-1 labels from the data set, (re)applies the
     * cache mode and allocates zeroed alphas. Labels are validated before any
     * field is modified, so a rejected data set leaves the model unchanged.
     *
     * @param dataSet binary data set whose categories are 0 or 1
     * @throws IllegalArgumentException if a data point's category is not 0 or 1
     */
    private void bookKeepingInit(ClassificationDataSet dataSet) {
        int N = dataSet.size();
        short[] labels = new short[N];
        for (int i = 0; i < N; ++i) {
            int category = dataSet.getDataPointCategory(i);
            if (category != 0 && category != 1) {
                throw new IllegalArgumentException("SVMnoBias is a binary classifier, but data point "
                        + i + " has category " + category);
            }
            labels[i] = (short) (category * 2 - 1);
        }
        this.vecs = dataSet.getDataVectors();
        this.weights = dataSet.getDataWeights();
        this.label = labels;
        this.setCacheMode(this.getCacheMode());
        this.alphas = new double[N];
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Sets the regularization constant.
     *
     * @param C2 a positive, finite value
     * @throws ArithmeticException if {@code C2} is not positive, is NaN, or is infinite
     */
    public void setC(double C2) {
        // "!(C2 > 0)" also rejects NaN, which "C2 <= 0" would let through.
        if (!(C2 > 0.0) || Double.isInfinite(C2)) {
            throw new ArithmeticException("C must be a positive constant");
        }
        this.C = C2;
    }

    /**
     * @return the regularization constant
     */
    public double getC() {
        return this.C;
    }

    /**
     * Sets the convergence tolerance. Larger values stop training earlier,
     * giving a less accurate solution.
     *
     * @param tolerance a positive value
     * @throws ArithmeticException if {@code tolerance} is not positive or is NaN
     */
    public void setTolerance(double tolerance) {
        if (!(tolerance > 0.0)) {
            throw new ArithmeticException("tolerance must be a positive value");
        }
        this.tolerance = tolerance;
    }

    /**
     * @return the convergence tolerance
     */
    public double getTolerance() {
        return this.tolerance;
    }
}

