/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear.kernelized;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;

/**
 * Kernelized online binary classifier that uses a "double updating" rule.
 * <p>
 * Every training point whose hinge loss {@code max(0, 1 - y*f(x))} is positive becomes a new
 * support vector. If some existing support vector {@code b} has margin at most 1 and its signed
 * kernel similarity to the new point, {@code y_b * y_t * k(x_b, x_t)}, is at most
 * {@code -rho}, then the coefficient of {@code b} is adjusted together with the new
 * coefficient. Otherwise only the new coefficient is set, capped at {@code C}.
 * <p>
 * Per-support-vector state (all lists are index-aligned):
 * <ul>
 * <li>{@link #S} - the support vectors</li>
 * <li>{@link #alphas} - signed coefficients; {@code signum(alphas[i])} is used as the label of
 * {@code S[i]}</li>
 * <li>{@link #f_s} - the margin {@code y_i * f(S[i])} of each support vector, maintained
 * incrementally</li>
 * <li>{@link #accelCache} - kernel acceleration cache for {@link #S}</li>
 * <li>{@link #kTmp} - scratch space holding {@code k(S[i], x_t)} for the point currently being
 * learned</li>
 * </ul>
 * {@link #update(DataPoint, double, int)} is {@code synchronized}; the scoring methods are not.
 */
public class DUOL
extends BaseUpdateableClassifier
implements BinaryScoreClassifier,
Parameterized {
    private static final long serialVersionUID = -4751569462573287056L;
    /** The kernel used to compare points. */
    @Parameter.ParameterHolder
    protected KernelTrick k;
    /** Support vectors. */
    protected List<Vec> S;
    /** Margin {@code y_i * f(x_i)} of each support vector. */
    protected List<Double> f_s;
    /** Signed coefficient of each support vector; its sign encodes the label. */
    protected List<Double> alphas;
    /** Kernel acceleration cache aligned with {@link #S}. */
    protected List<Double> accelCache;
    /** Scratch: kernel values between the current update point and each support vector. */
    protected DoubleList kTmp;
    /** Threshold controlling when the second (double) update is triggered. */
    protected double rho = 0.0;
    /** Regularization parameter bounding the step size of the single update. */
    protected double C = 10.0;

    /**
     * Creates a new DUOL classifier.
     *
     * @param k the kernel to use
     */
    public DUOL(KernelTrick k) {
        this.k = k;
        this.S = new ArrayList<Vec>();
        this.f_s = new DoubleList();
        this.alphas = new DoubleList();
    }

    /**
     * Copy constructor; deep-copies the kernel, the support vectors and all cached state.
     *
     * @param other the classifier to copy
     */
    protected DUOL(DUOL other) {
        this.k = other.k.clone();
        if (other.S != null) {
            this.S = new ArrayList<Vec>(other.S.size());
            for (Vec v : other.S) {
                this.S.add(v.clone());
            }
            this.f_s = new DoubleList(other.f_s);
            this.alphas = new DoubleList(other.alphas);
            if (other.accelCache != null) {
                this.accelCache = new DoubleList(other.accelCache);
            }
            if (other.kTmp != null) {
                this.kTmp = new DoubleList(other.kTmp);
            }
        }
        this.rho = other.rho;
        this.C = other.C;
    }

    @Override
    public DUOL clone() {
        return new DUOL(this);
    }

    /**
     * Sets the regularization parameter C.
     *
     * @param C2 a finite, strictly positive value
     * @throws IllegalArgumentException if {@code C2} is NaN, infinite, or not positive
     */
    public void setC(double C2) {
        if (Double.isNaN(C2) || C2 <= 0.0 || Double.isInfinite(C2)) {
            throw new IllegalArgumentException("C parameter must be in range (0, inf) not " + C2);
        }
        this.C = C2;
    }

    /**
     * @return the regularization parameter C
     */
    public double getC() {
        return this.C;
    }

    /**
     * Sets the threshold rho; a double update happens when the most conflicting support
     * vector's signed similarity to the new point is at most {@code -rho}.
     *
     * @param rho the threshold (not validated)
     */
    public void setRho(double rho) {
        this.rho = rho;
    }

    /**
     * @return the double-update threshold rho
     */
    public double getRho() {
        return this.rho;
    }

    /**
     * @param k the kernel to use
     */
    public void setKernel(KernelTrick k) {
        this.k = k;
    }

    /**
     * @return the kernel in use
     */
    public KernelTrick getKernel() {
        return this.k;
    }

    /**
     * Resets all learned state.
     *
     * @throws FailedToFitException if there are no numeric features or the target is not binary
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (numericAttributes <= 0) {
            throw new FailedToFitException("DUOL requires numeric features");
        }
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("DUOL supports only binary classification");
        }
        this.S = new ArrayList<Vec>();
        this.f_s = new DoubleList();
        this.alphas = new DoubleList();
        this.accelCache = new DoubleList();
        this.kTmp = new DoubleList();
    }

    /**
     * Learns from a single point. The {@code weight} is ignored (see
     * {@link #supportsWeightedData()}).
     *
     * @param dataPoint the point to learn from
     * @param weight ignored
     * @param targetClass 0 or 1; mapped to the label -1 or +1
     */
    @Override
    public synchronized void update(DataPoint dataPoint, double weight, int targetClass) {
        final double y_t = targetClass * 2 - 1;
        final Vec x_t = dataPoint.getNumericalValues();
        final List<Double> qi = this.k.getQueryInfo(x_t);
        // Also fills kTmp with k(S[i], x_t) for every current support vector.
        final double score = this.score(x_t, qi, true);
        final double loss_t = Math.max(0.0, 1.0 - y_t * score);
        if (loss_t <= 0.0) {
            return;
        }

        // Find the support vector with margin <= 1 that conflicts most with x_t,
        // i.e. that minimizes y_i * y_t * k(x_i, x_t). Ties go to the later index.
        int b = -1;
        double w_min = Double.POSITIVE_INFINITY;
        for (int i = 0; i < this.S.size(); ++i) {
            if (!(this.f_s.get(i) <= 1.0)) {
                continue;
            }
            final double w_i = Math.signum(this.alphas.get(i)) * y_t * this.kTmp.get(i);
            if (w_i <= w_min) {
                w_min = w_i;
                b = i;
            }
        }

        final double k_t = this.k.eval(0, 0, Arrays.asList(x_t), qi);
        // b >= 0 guards against rho == -inf, where w_min == +inf would otherwise pass.
        if (b >= 0 && w_min <= -this.rho) {
            final double alpha_b = this.alphas.get(b);
            final double y_b = Math.signum(alpha_b);
            final double k_b = this.k.eval(b, b, this.S, this.accelCache);
            final double k_tb = this.kTmp.get(b);
            final double w_tb = y_t * y_b * k_tb;
            final double gamma_hat_b = Math.abs(alpha_b);
            // f_s already holds the margin y_b * f(x_b), so the hinge loss is 1 - margin.
            final double loss_b = 1.0 - this.f_s.get(b);
            final double CmGhb = this.C - gamma_hat_b;
            final double gamma_t;
            final double gamma_b_delta;
            if (k_t * this.C + w_tb * CmGhb - loss_t < 0.0 && k_b * CmGhb + w_tb * this.C - loss_b < 0.0) {
                gamma_t = this.C;
                gamma_b_delta = CmGhb;
            } else if ((w_tb * w_tb * this.C - w_tb * loss_b - k_t * k_b * this.C + k_b * loss_t) / k_b > 0.0 && this.isIn((loss_b - w_tb * this.C) / k_b, -gamma_hat_b, CmGhb)) {
                gamma_t = this.C;
                gamma_b_delta = (loss_b - w_tb * this.C) / k_b;
            } else if (this.isIn((loss_t - w_tb * CmGhb) / k_t, 0.0, this.C) && loss_b - k_b * CmGhb - w_tb * (loss_t - w_tb * CmGhb) / k_t > 0.0) {
                // Must be the same quantity whose range was just checked.
                gamma_t = (loss_t - w_tb * CmGhb) / k_t;
                gamma_b_delta = CmGhb;
            } else {
                final double denom = k_t * k_b - w_tb * w_tb;
                gamma_t = (k_b * loss_t - w_tb * loss_b) / denom;
                gamma_b_delta = (k_t * loss_b - w_tb * loss_t) / denom;
            }
            final double gamma_b = gamma_hat_b + gamma_b_delta;
            this.addSupportVector(x_t, qi, k_t, y_t * gamma_t, y_t * score);
            // Update every margin for the new coefficient and the change to alpha_b.
            // Uses the old alpha_b sign, so alpha_b is only overwritten afterwards.
            for (int i = 0; i < this.S.size(); ++i) {
                final double y_i = Math.signum(this.alphas.get(i));
                this.f_s.set(i, this.f_s.get(i) + y_i * gamma_t * y_t * this.kTmp.get(i) + y_i * gamma_b_delta * y_b * this.k.eval(i, b, this.S, this.accelCache));
            }
            this.alphas.set(b, y_b * gamma_b);
        } else {
            final double gamma_t = Math.min(this.C, loss_t / k_t);
            this.addSupportVector(x_t, qi, k_t, y_t * gamma_t, y_t * score);
            for (int i = 0; i < this.S.size(); ++i) {
                final double y_i = Math.signum(this.alphas.get(i));
                this.f_s.set(i, this.f_s.get(i) + y_i * gamma_t * y_t * this.kTmp.get(i));
            }
        }
    }

    /**
     * Appends a new support vector and its per-vector state.
     *
     * @param x_t the new support vector
     * @param qi its kernel query info (may be null; nothing is cached then)
     * @param k_t {@code k(x_t, x_t)}, appended to {@link #kTmp}
     * @param alpha_t its signed coefficient
     * @param margin_t its margin {@code y_t * f(x_t)} before this update's contribution
     */
    private void addSupportVector(Vec x_t, List<Double> qi, double k_t, double alpha_t, double margin_t) {
        this.S.add(x_t);
        if (qi != null) {
            this.accelCache.addAll(qi);
        }
        this.kTmp.add(k_t);
        this.alphas.add(alpha_t);
        this.f_s.add(margin_t);
    }

    /**
     * @return true if {@code a <= x <= b}
     */
    private boolean isIn(double x, double a, double b) {
        return a <= x && x <= b;
    }

    /**
     * Computes {@code f(x) = sum_i alphas[i] * k(S[i], x)}.
     *
     * @param x the query point
     * @param qi kernel query info for {@code x}
     * @param store if true, {@link #kTmp} is cleared and refilled with each {@code k(S[i], x)}
     * @return the raw decision value
     */
    private double score(Vec x, List<Double> qi, boolean store) {
        if (store) {
            this.kTmp.clear();
        }
        double score = 0.0;
        for (int i = 0; i < this.S.size(); ++i) {
            final double tmp = this.k.eval(i, x, qi, this.S, this.accelCache);
            if (store) {
                this.kTmp.add(tmp);
            }
            score += this.alphas.get(i) * tmp;
        }
        return score;
    }

    /**
     * Computes the raw decision value without touching {@link #kTmp}.
     */
    private double score(Vec x, List<Double> qi) {
        return this.score(x, qi, false);
    }

    /**
     * Assigns probability 1 to class 0 if the score is negative, otherwise to class 1.
     *
     * @throws UntrainedModelException if {@link #alphas} is null
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.alphas == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        CategoricalResults cr = new CategoricalResults(2);
        double score = this.getScore(data);
        if (score < 0.0) {
            cr.setProb(0, 1.0);
        } else {
            cr.setProb(1, 1.0);
        }
        return cr;
    }

    @Override
    public double getScore(DataPoint dp) {
        Vec x = dp.getNumericalValues();
        return this.score(x, this.k.getQueryInfo(x));
    }

    /**
     * @return false; the weight passed to {@link #update} is ignored
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Suggests a search distribution for C, independent of the data set.
     *
     * @param d the data set (unused)
     * @return a log-uniform distribution over [1e-4, 1e5]
     */
    public static Distribution guessC(DataSet d) {
        return new LogUniform(1.0E-4, 100000.0);
    }
}

