/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.svm.extended;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import jsat.DataSet;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;
import jsat.linear.VecWithNorm;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.IndexTable;
import jsat.utils.IntList;

/**
 * Online multi-hyperplane linear classifier. Every class owns a budgeted set of
 * weight vectors ("hyperplanes"), and a point is predicted as the class owning the
 * hyperplane with the largest dot product against it. The class name suggests the
 * online variant of the Adaptive Multi-hyperplane Machine (AMM) of Wang et al.;
 * TODO confirm against the paper.
 * <p>
 * Each update performs one stochastic gradient step with step size
 * {@code 1/(lambda * t)}. It may add a new hyperplane (at most
 * {@link #getClassBudget()} per class). Whenever the update counter reaches a
 * multiple of {@link #getPruneFrequency()}, the smallest-norm hyperplanes are
 * removed.
 */
public class OnlineAMM extends BaseUpdateableClassifier implements Parameterized {

    private static final long serialVersionUID = 8291068484917637037L;

    /** Per class: map from hyperplane ID to its weight vector (insertion ordered). */
    protected List<Map<Integer, Vec>> weightMatrix;
    /** Per class: the next unused hyperplane ID. IDs are never reused after pruning. */
    protected int[] nextID;
    /** Regularization strength; also determines the step size {@code 1/(lambda * t)}. */
    protected double lambda;
    /** Pruning frequency: pruning runs when the update counter is a multiple of this. */
    protected int k;
    /** Pruning constant; larger values prune more aggressively. */
    protected double c;
    /** Update counter; reset to 1 by {@link #setUp}. */
    protected int time;
    /** Maximum number of hyperplanes a single class may hold. */
    protected int classBudget;

    public static final int DEFAULT_PRUNE_FREQUENCY = 10000;
    public static final double DEFAULT_PRUNE_CONSTANT = 10.0;
    public static final int DEFAULT_CLASS_BUDGET = 50;
    public static final double DEFAULT_REGULARIZER = 0.01;

    /** Creates a classifier with the default regularizer and class budget. */
    public OnlineAMM() {
        this(DEFAULT_REGULARIZER);
    }

    /**
     * Creates a classifier with the default class budget.
     *
     * @param lambda the regularization strength, must be positive and finite
     */
    public OnlineAMM(double lambda) {
        this(lambda, DEFAULT_CLASS_BUDGET);
    }

    /**
     * Creates a classifier with the default pruning frequency and pruning constant.
     *
     * @param lambda      the regularization strength, must be positive and finite
     * @param classBudget the maximum number of hyperplanes per class, must be at least 1
     */
    public OnlineAMM(double lambda, int classBudget) {
        this.setLambda(lambda);
        this.setClassBudget(classBudget);
        this.setPruneFrequency(DEFAULT_PRUNE_FREQUENCY);
        this.setC(DEFAULT_PRUNE_CONSTANT);
    }

    /**
     * Copy constructor. Creates a deep copy: every hyperplane vector is cloned.
     *
     * @param toCopy the classifier to copy
     */
    public OnlineAMM(OnlineAMM toCopy) {
        if (toCopy.weightMatrix != null) {
            this.weightMatrix = new ArrayList<Map<Integer, Vec>>(toCopy.weightMatrix.size());
            for (Map<Integer, Vec> oldW : toCopy.weightMatrix) {
                Map<Integer, Vec> newW = new LinkedHashMap<Integer, Vec>(oldW.size());
                for (Map.Entry<Integer, Vec> entry : oldW.entrySet()) {
                    newW.put(entry.getKey(), entry.getValue().clone());
                }
                this.weightMatrix.add(newW);
            }
            this.nextID = Arrays.copyOf(toCopy.nextID, toCopy.nextID.length);
        }
        this.time = toCopy.time;
        this.lambda = toCopy.lambda;
        this.k = toCopy.k;
        this.c = toCopy.c;
        this.classBudget = toCopy.classBudget;
        this.setEpochs(toCopy.getEpochs());
    }

    @Override
    public OnlineAMM clone() {
        return new OnlineAMM(this);
    }

    /**
     * Sets the regularization strength.
     *
     * @param lambda a positive, finite value
     * @throws IllegalArgumentException if {@code lambda} is not positive and finite
     */
    public void setLambda(double lambda) {
        if (lambda <= 0.0 || Double.isNaN(lambda) || Double.isInfinite(lambda)) {
            throw new IllegalArgumentException("Lambda must be positive, not " + lambda);
        }
        this.lambda = lambda;
    }

    /** @return the regularization strength */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets how many updates occur between pruning passes.
     *
     * @param frequency a value of at least 1
     * @throws IllegalArgumentException if {@code frequency < 1}
     */
    public void setPruneFrequency(int frequency) {
        if (frequency < 1) {
            throw new IllegalArgumentException("Pruning frequency must be positive, not " + frequency);
        }
        this.k = frequency;
    }

    /** @return the number of updates between pruning passes */
    public int getPruneFrequency() {
        return this.k;
    }

    /**
     * Sets the pruning constant.
     *
     * @param c a positive, finite value
     * @throws IllegalArgumentException if {@code c} is not positive and finite
     */
    public void setC(double c) {
        if (c <= 0.0 || Double.isNaN(c) || Double.isInfinite(c)) {
            throw new IllegalArgumentException("C must be positive, not " + c);
        }
        this.c = c;
    }

    /** @return the pruning constant */
    public double getC() {
        return this.c;
    }

    /**
     * Sets the maximum number of hyperplanes any single class may hold.
     *
     * @param classBudget a value of at least 1
     * @throws IllegalArgumentException if {@code classBudget < 1}
     */
    public void setClassBudget(int classBudget) {
        if (classBudget < 1) {
            throw new IllegalArgumentException("Number of hyperplanes must be positive, not " + classBudget);
        }
        this.classBudget = classBudget;
    }

    /** @return the maximum number of hyperplanes per class */
    public int getClassBudget() {
        return this.classBudget;
    }

    /**
     * Resets the model. Each class starts with no hyperplanes, and the update
     * counter is reset to 1.
     *
     * @throws FailedToFitException if there are no numeric attributes
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (numericAttributes < 1) {
            throw new FailedToFitException("OnlineAMM requires numeric features to perform classification");
        }
        int numClasses = predicting.getNumOfCategories();
        this.weightMatrix = new ArrayList<Map<Integer, Vec>>(numClasses);
        for (int i = 0; i < numClasses; ++i) {
            this.weightMatrix.add(new LinkedHashMap<Integer, Vec>());
        }
        this.nextID = new int[numClasses];
        this.time = 1;
    }

    /**
     * Performs one online update. The {@code weight} argument is ignored, since
     * weighted data is not supported (see {@link #supportsWeightedData()}).
     */
    @Override
    public void update(DataPoint dataPoint, double weight, int y_t) {
        this.update(dataPoint, y_t, Integer.MIN_VALUE);
    }

    /**
     * Performs one online update.
     *
     * @param dataPoint the training point; only its numeric values are used
     * @param y_t       the true class index
     * @param z_t       the ID of the true-class hyperplane previously assigned to
     *                  this point. Pass {@link Integer#MIN_VALUE} if there is no
     *                  assignment. An ID that no longer exists (for example, one
     *                  that was pruned) is treated the same as no assignment.
     * @return the ID of the true-class hyperplane assigned to the point, or -1 if
     *         none. The hyperplane may have been removed by pruning during this same
     *         call; a later call with the returned ID re-resolves it.
     */
    protected int update(DataPoint dataPoint, int y_t, int z_t) {
        final Vec x_t = dataPoint.getNumericalValues();
        final Map<Integer, Vec> w_yt = this.weightMatrix.get(y_t);

        // Resolve the true-class hyperplane z_t. With no usable assignment, pick the
        // best-scoring one. "No hyperplane" implicitly scores 0, so z_t stays -1 unless
        // some hyperplane scores >= 0. Keys are always >= 0, so -1 and MIN_VALUE never
        // match a stored key.
        double z_t_val;
        if (z_t != Integer.MIN_VALUE && w_yt.containsKey(z_t)) {
            z_t_val = w_yt.get(z_t).dot(x_t);
        } else {
            z_t = -1;
            z_t_val = 0.0;
            for (Map.Entry<Integer, Vec> entry : w_yt.entrySet()) {
                double score = x_t.dot(entry.getValue());
                if (score >= z_t_val) {
                    z_t = entry.getKey();
                    z_t_val = score;
                }
            }
        }

        final double eta = 1.0 / (this.lambda * (double) this.time++);

        // Highest-scoring hyperplane (i_t, j_t) over all *other* classes. Again,
        // "no hyperplane" scores 0, which leaves j_t == -1 with a default wrong class.
        int i_t = y_t > 0 ? 0 : 1;
        int j_t = -1;
        double i_t_val = 0.0;
        for (int cls = 0; cls < this.weightMatrix.size(); ++cls) {
            if (cls == y_t) {
                continue;
            }
            for (Map.Entry<Integer, Vec> entry : this.weightMatrix.get(cls).entrySet()) {
                double score = x_t.dot(entry.getValue());
                if (score > i_t_val) {
                    i_t = cls;
                    j_t = entry.getKey();
                    i_t_val = score;
                }
            }
        }

        // The hinge loss max(0, 1 + i_t_val - z_t_val) is non-zero.
        final boolean nonZeroLoss = 0.0 < 1.0 + i_t_val - z_t_val;
        // Identical to -(eta*lambda - 1) under IEEE round-to-nearest.
        final double decay = 1.0 - eta * this.lambda;

        for (int i = 0; i < this.weightMatrix.size(); ++i) {
            Map<Integer, Vec> w_i = this.weightMatrix.get(i);
            // Shrink every hyperplane (regularization), then apply the loss gradient
            // to the two hyperplanes involved in the violated margin.
            for (Map.Entry<Integer, Vec> entry : w_i.entrySet()) {
                int j = entry.getKey();
                Vec w_ij = entry.getValue();
                w_ij.mutableMultiply(decay);
                if (!nonZeroLoss) {
                    continue;
                }
                if (i == i_t && j == j_t) {
                    w_ij.mutableSubtract(eta, x_t);
                } else if (i == y_t && j == z_t) {
                    w_ij.mutableAdd(eta, x_t);
                }
            }

            // If a participating class had no hyperplane for this point, create one
            // (budget permitting) holding the gradient step.
            if (!nonZeroLoss || w_i.size() >= this.classBudget) {
                continue;
            }
            if (i == i_t && j_t == -1) {
                w_i.put(this.nextID[i]++, newHyperplane(x_t, -eta));
            } else if (i == y_t && z_t == -1) {
                int newID = this.nextID[i]++;
                w_i.put(newID, newHyperplane(x_t, eta));
                // Return the map key, not w_i.size()-1: IDs and sizes diverge after pruning.
                z_t = newID;
            }
        }

        if (this.time % this.k == 0) {
            prune();
        }
        return z_t;
    }

    /**
     * Creates a new hyperplane equal to {@code scale * x}, backed by a dense copy of
     * {@code x}. It is wrapped in {@code VecWithNorm} and {@code ScaledVector}, as the
     * original code did; presumably this keeps the repeated scaling cheap (TODO confirm).
     */
    private static Vec newHyperplane(Vec x, double scale) {
        double norm = x.pNorm(2.0);
        Vec v = new ScaledVector(new VecWithNorm(new DenseVector(x), norm));
        v.mutableMultiply(scale);
        return v;
    }

    /**
     * Removes hyperplanes in ascending order of squared norm. Each removed
     * hyperplane's squared norm is subtracted from the budget
     * {@code c / ((time - 1) * lambda)}. Removal stops at the first hyperplane whose
     * squared norm is not below the remaining budget.
     */
    private void prune() {
        double threshold = this.c / ((double) (this.time - 1) * this.lambda);
        IntList classOwner = new IntList(this.weightMatrix.size());
        IntList vecID = new IntList(this.weightMatrix.size());
        DoubleList normVal = new DoubleList(this.weightMatrix.size());
        for (int cls = 0; cls < this.weightMatrix.size(); ++cls) {
            for (Map.Entry<Integer, Vec> entry : this.weightMatrix.get(cls).entrySet()) {
                Vec v = entry.getValue();
                classOwner.add(cls);
                vecID.add(entry.getKey());
                normVal.add(v.dot(v));
            }
        }

        IndexTable order = new IndexTable(normVal);
        for (int rank = 0; rank < normVal.size(); ++rank) {
            int idx = order.index(rank);
            double sqNorm = normVal.get(idx);
            if (sqNorm >= threshold) {
                break;
            }
            threshold -= sqNorm;
            this.weightMatrix.get(classOwner.getI(idx)).remove(vecID.getI(idx));
        }
    }

    /**
     * Predicts the class owning the hyperplane with the largest dot product against
     * the point. The result is a hard assignment (probability 1.0). If no class has
     * any hyperplane, class 0 is returned.
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        Vec x = data.getNumericalValues();
        int bestClass = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int cls = 0; cls < this.weightMatrix.size(); ++cls) {
            for (Vec w_kj : this.weightMatrix.get(cls).values()) {
                double score = x.dot(w_kj);
                if (score > bestScore) {
                    bestClass = cls;
                    bestScore = score;
                }
            }
        }
        CategoricalResults cr = new CategoricalResults(this.weightMatrix.size());
        cr.setProb(bestClass, 1.0);
        return cr;
    }

    /** @return {@code false}; instance weights are ignored by {@link #update}. */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns a search distribution for the regularization parameter. It is a
     * log-uniform distribution over [1e-7, 1e-2]; the data set is not inspected.
     *
     * @param d the data set (unused)
     * @return the distribution to sample lambda from
     */
    public static Distribution guessLambda(DataSet d) {
        return new LogUniform(1.0E-7, 0.01);
    }
}

