/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.svm;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.svm.SVMnoBias;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.clustering.kmeans.ElkanKernelKMeans;
import jsat.clustering.kmeans.KernelKMeans;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.RBFKernel;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.concurrent.ParallelUtils;

/**
 * A binary kernel SVM (without a bias term) trained by a multi-level
 * divide-and-conquer scheme.
 * <p>
 * Training walks the levels {@code l = l_max, l_max - 1, ..., l_early}. At
 * each level:
 * <ol>
 * <li>a random sample is drawn: from all training points on the first level,
 * and from the points with a non-zero alpha on later levels;</li>
 * <li>the sample is partitioned into at most {@code k^l} clusters with kernel
 * k-means;</li>
 * <li>every remaining training point is assigned to its nearest cluster;</li>
 * <li>an independent {@link SVMnoBias} is trained on each cluster. From the
 * second level on, it is warm-started from the alphas of the previous
 * level.</li>
 * </ol>
 * If {@code l_early == 0}, a final {@link SVMnoBias} is trained on the whole
 * data set, warm-started from the alphas of the last level.
 * <p>
 * At prediction time a point is routed to the model of its nearest cluster.
 * NOTE(review): this appears to follow the DC-SVM algorithm (with early
 * prediction) of Hsieh, Si and Dhillon; confirm before relying on that.
 */
public class DCSVM extends SupportVectorLearner implements Classifier, Parameterized, BinaryScoreClassifier {

    /** Regularization constant handed to every sub-problem solver. */
    private double C = 1.0;
    /** NOTE(review): stored and copied, but not currently passed to the sub-problem solvers. */
    private double tolerance = 0.001;
    /** Kernel k-means model of the last level trained; used to route queries to a sub-model. */
    private KernelKMeans clusters;
    /** Maximum number of points sampled for clustering on each level. */
    private int m = 2000;
    /** Level training starts at (the level with the most clusters, {@code k^l_max}). */
    private int l_max = 4;
    /** Level training stops at; 0 additionally trains one SVM on all the data. */
    private int l_early = 3;
    /** Branching factor: level {@code l} uses at most {@code k^l} clusters. */
    private int k = 4;
    /** One trained SVM per cluster id of the last level trained. */
    private Map<Integer, SVMnoBias> early_models;
    /** Kernel-cache budget given to each sub-problem solver; {@code <= 0} disables caching. */
    private long cache_size = 0L;

    /**
     * Creates a DC-SVM using the given kernel.
     * <p>
     * The sub-problem cache budget defaults to half of the JVM's currently
     * free memory.
     *
     * @param k the kernel to use
     */
    public DCSVM(KernelTrick k) {
        super(k, SupportVectorLearner.CacheMode.ROWS);
        this.cache_size = Runtime.getRuntime().freeMemory() / 2L;
    }

    /** Creates a DC-SVM using a default {@link RBFKernel}. */
    public DCSVM() {
        this(new RBFKernel());
    }

    /**
     * Copy constructor.
     * <p>
     * The clustering model and every early model are cloned.
     *
     * @param toCopy the object to copy
     */
    public DCSVM(DCSVM toCopy) {
        super(toCopy);
        this.C = toCopy.C;
        this.tolerance = toCopy.tolerance;
        if (toCopy.clusters != null) {
            this.clusters = toCopy.clusters.clone();
        }
        this.cache_size = toCopy.cache_size;
        this.m = toCopy.m;
        this.l_early = toCopy.l_early;
        this.l_max = toCopy.l_max;
        this.k = toCopy.k;
        if (toCopy.early_models != null) {
            this.early_models = new ConcurrentHashMap<Integer, SVMnoBias>();
            for (Map.Entry<Integer, SVMnoBias> x : toCopy.early_models.entrySet()) {
                this.early_models.put(x.getKey(), x.getValue().clone());
            }
        }
    }

    /**
     * Sets the level training starts at.
     * <p>
     * That level uses at most {@code k^l_max} clusters.
     *
     * @param l_max the starting level
     * @throws IllegalArgumentException if {@code l_max} is negative
     */
    public void setStartLevel(int l_max) {
        if (l_max < 0) {
            throw new IllegalArgumentException("l_max must be a non-negative integer, not " + l_max);
        }
        this.l_max = l_max;
    }

    /** @return the level training starts at */
    public int getStartLevel() {
        return this.l_max;
    }

    /**
     * Sets the level training stops at.
     * <p>
     * A value of 0 also trains one final SVM on the whole data set.
     *
     * @param l_early the final level
     * @throws IllegalArgumentException if {@code l_early} is negative
     */
    public void setEndLevel(int l_early) {
        if (l_early < 0) {
            throw new IllegalArgumentException("l_early must be a non-negative integer, not " + l_early);
        }
        this.l_early = l_early;
    }

    /** @return the level training stops at */
    public int getEndLevel() {
        return this.l_early;
    }

    /**
     * Sets the maximum number of points sampled for clustering on each level.
     *
     * @param m the sample size
     * @throws IllegalArgumentException if {@code m} is not positive
     */
    public void setClusterSampleSize(int m) {
        if (m <= 0) {
            throw new IllegalArgumentException("Cluster Sample Size must be a positive integer, not " + m);
        }
        this.m = m;
    }

    /** @return the maximum number of points sampled for clustering on each level */
    public int getClusterSampleSize() {
        return this.m;
    }

    /**
     * Classifies a point by the sign of {@link #getScore(DataPoint)}.
     * <p>
     * A positive score gives class 1; otherwise the result is class 0. All
     * probability mass goes to the chosen class.
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        double sum = this.getScore(data);
        if (sum > 0.0) {
            cr.setProb(1, 1.0);
        } else {
            cr.setProb(0, 1.0);
        }
        return cr;
    }

    /**
     * Returns the decision value of the sub-model responsible for {@code dp}.
     * <p>
     * When only one sub-model exists it is used directly. Otherwise the point
     * is routed to the model of its nearest kernel k-means cluster.
     *
     * @param dp the point to score
     * @return the decision value; positive means class 1
     * @throws UntrainedModelException if the model has not been trained, or
     *         training produced no sub-models
     */
    @Override
    public double getScore(DataPoint dp) {
        if (this.vecs == null || this.early_models == null || this.early_models.isEmpty()) {
            throw new UntrainedModelException("Classifier has yet to be trained");
        }
        if (this.early_models.size() == 1) {
            // The single model may be keyed by any cluster id, not necessarily 0
            return this.early_models.values().iterator().next().getScore(dp);
        }
        Vec x = dp.getNumericalValues();
        int c = this.clusters.findClosestCluster(x, this.getKernel().getQueryInfo(x));
        return this.early_models.get(c).getScore(dp);
    }

    /**
     * Trains the model level by level, as described in the class documentation.
     *
     * @param dataSet the binary training data
     * @param parallel whether clustering, assignment and sub-problem training
     *        may use multiple threads
     * @throws IllegalStateException if the end level is greater than the start
     *         level, since no level would be trained
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        if (this.l_early > this.l_max) {
            throw new IllegalStateException("End level (" + this.l_early
                    + ") must not be greater than the start level (" + this.l_max + ")");
        }
        final int N = dataSet.size();
        this.vecs = dataSet.getDataVectors();
        this.early_models = new ConcurrentHashMap<Integer, SVMnoBias>();
        // NOTE(review): presumably (re)initialises this.accelCache, which the
        // nearest-cluster assignment below reads; confirm in SupportVectorLearner
        this.setCacheMode(SupportVectorLearner.CacheMode.NONE);
        this.alphas = new double[N];
        final int[] group = new int[N];
        IntList indicies = new IntList();

        for (int l = this.l_max; l >= this.l_early; --l) {
            this.early_models.clear();
            int k_l = (int) Math.pow(this.k, l);
            // Sample more points when the clusters would otherwise be tiny
            int M = N / k_l < 7 ? k_l * 7 : this.m;

            // Candidates: every point on the first level, non-zero-alpha points afterwards
            indicies.clear();
            if (l == this.l_max) {
                ListUtils.addRange(indicies, 0, N, 1);
            } else {
                for (int i = 0; i < N; ++i) {
                    if (this.alphas[i] != 0.0) {
                        indicies.add(i);
                    }
                }
            }
            Collections.shuffle(indicies);

            // Position i of toCluster holds training point indicies.get(i); the
            // cluster labels are mapped back through the same indices below
            ClassificationDataSet toCluster = emptyLike(dataSet);
            int sampleSize = Math.min(M, indicies.size());
            for (int i = 0; i < sampleSize; ++i) {
                int idx = indicies.get(i);
                toCluster.addDataPoint(dataSet.getDataPoint(idx), dataSet.getDataPointCategory(idx));
            }

            this.clusters = new ElkanKernelKMeans(this.getKernel());
            this.clusters.setMaximumIterations(100);
            k_l = Math.min(k_l, toCluster.size() / 2);
            int[] sub_results;
            if (k_l <= 1) {
                // Too few points to split: every training point goes to cluster 0
                sub_results = new int[N];
                indicies.clear();
                ListUtils.addRange(indicies, 0, N, 1);
            } else {
                sub_results = this.clusters.cluster((DataSet) toCluster, k_l, parallel, (int[]) null);
            }

            Arrays.fill(group, -1);
            IntSet found_clusters = new IntSet(Math.max(k_l, 1));
            for (int i = 0; i < sub_results.length; ++i) {
                int idx = indicies.get(i);
                group[idx] = sub_results[i];
                found_clusters.add(Integer.valueOf(sub_results[i]));
            }

            // Assign every point that was not clustered to its nearest cluster
            ParallelUtils.run(parallel, N, i -> {
                if (group[i] >= 0) {
                    return;
                }
                List<Double> qi = null;
                if (this.accelCache != null) {
                    int multiplier = this.accelCache.size() / N;
                    qi = this.accelCache.subList(i * multiplier, i * multiplier + multiplier);
                }
                group[i] = this.clusters.findClosestCluster(this.vecs.get(i), qi);
            });

            for (int c : found_clusters) {
                this.trainCluster(dataSet, c, group, l == this.l_max, parallel);
            }
        }

        if (this.l_early == 0) {
            SVMnoBias svm = this.newSubSolver(N);
            svm.train(dataSet, Arrays.copyOf(this.alphas, this.alphas.length), parallel);
            this.early_models.clear();
            this.early_models.put(0, svm);
            System.arraycopy(svm.alphas, 0, this.alphas, 0, N);
        }
    }

    /**
     * Trains the sub-problem SVM for cluster {@code c}.
     * <p>
     * The model is stored in {@code early_models} under key {@code c}, and the
     * resulting alphas are copied back into {@code this.alphas} at the points'
     * original positions.
     *
     * @param dataSet the full training set
     * @param c the cluster id to train
     * @param group the cluster id of every training point
     * @param coldStart {@code true} to train from scratch, {@code false} to
     *        warm-start from the absolute values of the current alphas
     * @param parallel whether the solver may use multiple threads
     */
    private void trainCluster(ClassificationDataSet dataSet, int c, int[] group, boolean coldStart, boolean parallel) {
        int N = dataSet.size();
        ClassificationDataSet V_c = emptyLike(dataSet);
        DoubleList V_alphas = new DoubleList();
        IntList orig_index = new IntList();
        for (int i = 0; i < N; ++i) {
            if (group[i] != c) {
                continue;
            }
            V_c.addDataPoint(dataSet.getDataPoint(i), dataSet.getDataPointCategory(i));
            V_alphas.add(Math.abs(this.alphas[i]));
            orig_index.add(i);
        }
        SVMnoBias svm = this.newSubSolver(V_alphas.size());
        if (coldStart) {
            svm.train(V_c, parallel);
        } else {
            svm.train(V_c, V_alphas.getBackingArray(), parallel);
        }
        this.early_models.put(c, svm);
        for (int i = 0; i < orig_index.size(); ++i) {
            this.alphas[orig_index.get(i)] = svm.alphas[i];
        }
    }

    /**
     * Creates a sub-problem solver configured with this model's kernel and
     * {@code C}.
     * <p>
     * It also gets a kernel cache sized for {@code n} points when a cache
     * budget is set.
     *
     * @param n the number of training points the solver will see
     * @return a configured, untrained solver
     */
    private SVMnoBias newSubSolver(int n) {
        SVMnoBias svm = new SVMnoBias(this.getKernel());
        svm.setC(this.C);
        if (this.cache_size > 0L) {
            svm.setCacheSize(n, this.cache_size);
        } else {
            svm.setCacheMode(SupportVectorLearner.CacheMode.NONE);
        }
        return svm;
    }

    /**
     * Creates an empty data set with the same layout as {@code dataSet}.
     * <p>
     * The layout copied is the numeric-feature count, the categorical
     * attributes and the target attribute.
     *
     * @param dataSet the data set whose layout is copied
     * @return a new, empty data set
     */
    private static ClassificationDataSet emptyLike(ClassificationDataSet dataSet) {
        return new ClassificationDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories(), dataSet.getPredicting());
    }

    /**
     * {@inheritDoc}
     * <p>
     * NOTE(review): {@code train} does not read weights explicitly; confirm
     * that {@link SVMnoBias} picks them up from the copied data points.
     */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    @Override
    public DCSVM clone() {
        return new DCSVM(this);
    }

    /**
     * Sets the regularization constant used by every sub-problem solver.
     *
     * @param C2 the regularization constant; must be positive
     * @throws ArithmeticException if {@code C2} is not positive or is NaN
     */
    @Parameter.WarmParameter(prefLowToHigh=true)
    public void setC(double C2) {
        // Negated comparison so that NaN is rejected as well
        if (!(C2 > 0.0)) {
            throw new ArithmeticException("C must be a positive constant");
        }
        this.C = C2;
    }

    /** @return the regularization constant */
    public double getC() {
        return this.C;
    }
}

