/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.linear.DenseVector;
import jsat.linear.MatrixStatistics;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.vectorcollection.VPTreeMV;
import jsat.utils.DoubleList;
import jsat.utils.IndexTable;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Static routines for choosing the initial seeds (starting centers) used by
 * clustering algorithms.
 * <p>
 * Every routine fills an {@code int[]} with row indices into the supplied
 * {@link DataSet}. The {@code List<Vec>} overloads additionally return clones
 * of the chosen rows' numeric vectors.
 */
public class SeedSelectionMethods {
    /** Utility class; not meant to be instantiated. */
    private SeedSelectionMethods() {
    }

    /**
     * Selects {@code k} seeds sequentially, without an acceleration cache.
     *
     * @param d the data set to choose seeds from
     * @param k the number of seeds to choose
     * @param dm the distance metric used by the distance-based strategies
     * @param rand source of randomness
     * @param selectionMethod the seeding strategy
     * @return clones of the numeric vectors of the chosen points
     */
    public static List<Vec> selectIntialPoints(DataSet d, int k, DistanceMetric dm, Random rand, SeedSelection selectionMethod) {
        return selectIntialPoints(d, k, dm, null, rand, selectionMethod, false);
    }

    /**
     * Selects {@code k} seeds sequentially.
     *
     * @param d the data set to choose seeds from
     * @param k the number of seeds to choose
     * @param dm the distance metric used by the distance-based strategies
     * @param accelCache acceleration cache for {@code dm} over {@code d}'s vectors, may be {@code null}
     * @param rand source of randomness
     * @param selectionMethod the seeding strategy
     * @return clones of the numeric vectors of the chosen points
     */
    public static List<Vec> selectIntialPoints(DataSet d, int k, DistanceMetric dm, List<Double> accelCache, Random rand, SeedSelection selectionMethod) {
        return selectIntialPoints(d, k, dm, accelCache, rand, selectionMethod, false);
    }

    /**
     * Selects {@code k} seeds without an acceleration cache.
     *
     * @param d the data set to choose seeds from
     * @param k the number of seeds to choose
     * @param dm the distance metric used by the distance-based strategies
     * @param rand source of randomness
     * @param selectionMethod the seeding strategy
     * @param parallel whether distance computations may run on multiple threads
     * @return clones of the numeric vectors of the chosen points
     */
    public static List<Vec> selectIntialPoints(DataSet d, int k, DistanceMetric dm, Random rand, SeedSelection selectionMethod, boolean parallel) {
        return selectIntialPoints(d, k, dm, null, rand, selectionMethod, parallel);
    }

    /**
     * Selects {@code k} seeds and returns copies of their vectors.
     *
     * @param d the data set to choose seeds from
     * @param k the number of seeds to choose
     * @param dm the distance metric used by the distance-based strategies
     * @param accelCache acceleration cache for {@code dm} over {@code d}'s vectors, may be {@code null}
     * @param rand source of randomness
     * @param selectionMethod the seeding strategy; if {@code null} no selection is
     *        performed and the result holds {@code k} copies of row 0
     * @param parallel whether distance computations may run on multiple threads
     * @return clones of the numeric vectors of the chosen points, in selection order
     */
    public static List<Vec> selectIntialPoints(DataSet d, int k, DistanceMetric dm, List<Double> accelCache, Random rand, SeedSelection selectionMethod, boolean parallel) {
        int[] indices = new int[k];
        selectIntialPoints(d, indices, dm, accelCache, rand, selectionMethod, parallel);
        List<Vec> vecs = new ArrayList<Vec>(k);
        for (int index : indices) {
            vecs.add(d.getDataPoint(index).getNumericalValues().clone());
        }
        return vecs;
    }

    /**
     * Fills {@code indices} with seed row indices, sequentially and without an acceleration cache.
     *
     * @param d the data set to choose seeds from
     * @param indices output array; its length is the number of seeds chosen
     * @param dm the distance metric used by the distance-based strategies
     * @param rand source of randomness
     * @param selectionMethod the seeding strategy
     */
    public static void selectIntialPoints(DataSet d, int[] indices, DistanceMetric dm, Random rand, SeedSelection selectionMethod) {
        selectIntialPoints(d, indices, dm, null, rand, selectionMethod, false);
    }

    /**
     * Fills {@code indices} with seed row indices, sequentially.
     *
     * @param d the data set to choose seeds from
     * @param indices output array; its length is the number of seeds chosen
     * @param dm the distance metric used by the distance-based strategies
     * @param accelCache acceleration cache for {@code dm} over {@code d}'s vectors, may be {@code null}
     * @param rand source of randomness
     * @param selectionMethod the seeding strategy
     */
    public static void selectIntialPoints(DataSet d, int[] indices, DistanceMetric dm, List<Double> accelCache, Random rand, SeedSelection selectionMethod) {
        selectIntialPoints(d, indices, dm, accelCache, rand, selectionMethod, false);
    }

    /**
     * Fills {@code indices} with seed row indices, without an acceleration cache.
     *
     * @param d the data set to choose seeds from
     * @param indices output array; its length is the number of seeds chosen
     * @param dm the distance metric used by the distance-based strategies
     * @param rand source of randomness
     * @param selectionMethod the seeding strategy
     * @param parallel whether distance computations may run on multiple threads
     */
    public static void selectIntialPoints(DataSet d, int[] indices, DistanceMetric dm, Random rand, SeedSelection selectionMethod, boolean parallel) {
        selectIntialPoints(d, indices, dm, null, rand, selectionMethod, parallel);
    }

    /**
     * Fills {@code indices} with the row indices of the chosen seeds.
     * <p>
     * If {@code selectionMethod} is {@code null} or {@code indices} is empty,
     * {@code indices} is left untouched.
     *
     * @param d the data set to choose seeds from
     * @param indices output array; its length is the number of seeds chosen
     * @param dm the distance metric used by the distance-based strategies
     * @param accelCache acceleration cache for {@code dm} over {@code d}'s vectors, may be {@code null}
     * @param rand source of randomness (unused by {@link SeedSelection#MEAN_QUANTILES})
     * @param selectionMethod the seeding strategy
     * @param parallel whether distance computations may run on multiple threads
     * @throws IllegalArgumentException if a strategy that needs distinct points
     *         ({@code RANDOM}, {@code KPP}, {@code KPP_TIA}, {@code KBB}, {@code KBB_TIA})
     *         is asked for more seeds than {@code d} has points
     */
    public static void selectIntialPoints(DataSet d, int[] indices, DistanceMetric dm, List<Double> accelCache, Random rand, SeedSelection selectionMethod, boolean parallel) {
        int k = indices.length;
        if (selectionMethod == null || k == 0) {
            return;
        }
        switch (selectionMethod) {
            case RANDOM:
                requireEnoughPoints(d, k, selectionMethod);
                randomSelection(indices, rand, d.size());
                break;
            case KPP_TIA:
                requireEnoughPoints(d, k, selectionMethod);
                kppSelectionTIA(indices, rand, d, k, dm, accelCache, parallel);
                break;
            case KPP:
                requireEnoughPoints(d, k, selectionMethod);
                kppSelection(indices, rand, d, k, dm, accelCache, parallel);
                break;
            case KBB_TIA:
                requireEnoughPoints(d, k, selectionMethod);
                kbbSelectionTIA(indices, rand, d, k, dm, accelCache, parallel);
                break;
            case KBB:
                requireEnoughPoints(d, k, selectionMethod);
                kbbSelection(indices, rand, d, k, dm, accelCache, parallel);
                break;
            case FARTHEST_FIRST:
                ffSelection(indices, rand, d, k, dm, accelCache, parallel);
                break;
            case MEAN_QUANTILES:
                mqSelection(indices, d, k, dm, accelCache, parallel);
                break;
            default:
                break;
        }
    }

    /**
     * Rejects requests for more distinct seeds than there are points.
     * Without this check the random fill loops below would never terminate.
     */
    private static void requireEnoughPoints(DataSet d, int k, SeedSelection method) {
        if (k > d.size()) {
            throw new IllegalArgumentException("Cannot select " + k + " distinct seeds with " + method
                    + " from a data set of only " + d.size() + " points");
        }
    }

    /** Fills {@code indices} with distinct, uniformly random row indices in {@code [0, n)}; requires {@code indices.length <= n}. */
    private static void randomSelection(int[] indices, Random rand, int n) {
        IntSet chosen = new IntSet(indices.length);
        while (chosen.size() != indices.length) {
            chosen.add(rand.nextInt(n));
        }
        int pos = 0;
        for (Integer i : chosen) {
            indices[pos++] = i;
        }
    }

    /**
     * Keeps the first {@code chosen} entries of {@code indices} and fills up to {@code k}
     * entries with distinct random rows. This is used when every remaining point has
     * (near) zero distance to the seeds, so distance-based sampling is meaningless.
     * Requires {@code k <= n}.
     */
    private static void fillRemainingRandomly(int[] indices, int chosen, int k, int n, Random rand) {
        IntSet ind = new IntSet();
        for (int i = 0; i < chosen; ++i) {
            ind.add(indices[i]);
        }
        while (ind.size() < k) {
            ind.add(rand.nextInt(n));
        }
        int pos = 0;
        for (Integer i : ind) {
            indices[pos++] = i;
        }
    }

    /**
     * k-means++ style seeding. The first seed is uniform at random. Each later seed is
     * drawn with probability proportional to {@code weight * D^2}, where {@code D} is
     * the distance to the nearest seed chosen so far.
     */
    private static void kppSelection(int[] indices, Random rand, DataSet d, int k, DistanceMetric dm, List<Double> accelCache, boolean parallel) {
        final int n = d.size();
        indices[0] = rand.nextInt(n);
        Vec w = d.getDataWeights();
        // closestDist[i]: squared distance from point i to its nearest seed so far
        double[] closestDist = new double[n];
        Arrays.fill(closestDist, Double.POSITIVE_INFINITY);
        List<Vec> X = d.getDataVectors();
        for (int j = 1; j < k; ++j) {
            final int newMeanIndx = indices[j - 1];
            double sqrdDistSum = ParallelUtils.run(parallel, n, (start, end) -> {
                double partial = 0.0;
                for (int i = start; i < end; ++i) {
                    double newDist = dm.dist(newMeanIndx, i, X, accelCache);
                    newDist *= newDist;
                    if (newDist < closestDist[i]) {
                        closestDist[i] = newDist;
                    }
                    partial += closestDist[i] * w.get(i);
                }
                return partial;
            }, (t, u) -> t + u);
            if (sqrdDistSum <= 1.0E-6) {
                fillRemainingRandomly(indices, j, k, n, rand);
                return;
            }
            // Inverse-CDF walk over the weighted squared distances
            double rndX = rand.nextDouble() * sqrdDistSum;
            int chosen = 0;
            double searchSum = closestDist[0] * w.get(0);
            while (searchSum < rndX && chosen < n - 1) {
                ++chosen;
                searchSum += closestDist[chosen] * w.get(chosen);
            }
            indices[j] = chosen;
        }
    }

    /**
     * k-means++ style seeding with two accelerations.
     * <ul>
     * <li>Triangle-inequality pruning: a point is only re-measured against the newest
     * seed when {@code gamma < 4 * closestDist}.</li>
     * <li>Exponential-race sampling: the next seed is the queue minimum of
     * {@code expo_sample / closestDist}, with stale keys repaired lazily.</li>
     * </ul>
     * NOTE(review): the same per-point exponential draw is reused for every round;
     * presumably intentional.
     */
    private static void kppSelectionTIA(int[] indices, Random rand, DataSet d, int k, DistanceMetric dm, List<Double> accelCache, boolean parallel) {
        final int n = d.size();
        double[] closestDist = new double[n];
        Arrays.fill(closestDist, Double.POSITIVE_INFINITY);
        // closest_mean[i]: position in indices of point i's nearest seed
        int[] closest_mean = new int[n];
        Vec w = d.getDataWeights();
        // Exp(w_i) draws. The smallest one selects the first seed with probability proportional to w_i.
        double[] expo_sample = new double[n];
        indices[0] = 0;
        for (int i = 0; i < n; ++i) {
            double p = rand.nextDouble();
            expo_sample[i] = -Math.log(1.0 - p) / w.get(i);
            if (expo_sample[i] < expo_sample[indices[0]]) {
                indices[0] = i;
            }
        }
        // sample_weight[i] = expo_sample[i] / closestDist[i]; the queue minimum becomes the next seed
        double[] sample_weight = new double[n];
        PriorityQueue<Integer> nextSample = new PriorityQueue<Integer>(Math.max(1, n), (a, b) -> Double.compare(sample_weight[a], sample_weight[b]));
        IntList dirtyItemsToFix = new IntList();
        // dirty[i]: closestDist[i] shrank after sample_weight[i] was computed, so its queue key is stale
        boolean[] dirty = new boolean[n];
        closestDist[indices[0]] = 0.0;
        List<Vec> X = d.getDataVectors();
        // gamma[m]: squared distance from seed m to the most recently chosen seed
        double[] gamma = new double[k];
        Arrays.fill(gamma, Double.MAX_VALUE);
        double totalSqrdDist = 0.0;
        for (int j = 1; j < k; ++j) {
            final int jj = j;
            final int newMeanIndx = indices[j - 1];
            double delta = ParallelUtils.run(parallel, n, (start, end) -> {
                double partial_sqrd_dist = 0.0;
                for (int i = start; i < end; ++i) {
                    // If d(old, new)^2 >= 4 d(x, old)^2, the new seed cannot be closer to x
                    if (!(gamma[closest_mean[i]] < 4.0 * closestDist[i])) {
                        continue;
                    }
                    double newDist = dm.dist(newMeanIndx, i, X, accelCache);
                    newDist *= newDist;
                    if (!(newDist < closestDist[i])) {
                        continue;
                    }
                    if (jj > 1) {
                        partial_sqrd_dist -= closestDist[i] * w.get(i);
                        dirty[i] = true;
                    } else {
                        sample_weight[i] = expo_sample[i] / newDist;
                    }
                    closest_mean[i] = jj - 1;
                    closestDist[i] = newDist;
                    partial_sqrd_dist += newDist * w.get(i);
                }
                return partial_sqrd_dist;
            }, (t, u) -> t + u);
            totalSqrdDist += delta;
            if (j == 1) {
                // PriorityQueue is not thread-safe, so it is filled here rather than inside
                // the (possibly parallel) loop. These are exactly the points updated above.
                for (int i = 0; i < n; ++i) {
                    if (i != indices[0] && closestDist[i] < Double.POSITIVE_INFINITY) {
                        nextSample.add(i);
                    }
                }
            }
            if (totalSqrdDist <= 1.0E-6 || nextSample.isEmpty()) {
                fillRemainingRandomly(indices, j, k, n, rand);
                return;
            }
            // Re-key stale entries sitting at the top of the heap
            while (!nextSample.isEmpty() && dirty[nextSample.peek()]) {
                dirtyItemsToFix.add(nextSample.poll());
            }
            for (int c = 0; c < dirtyItemsToFix.size(); ++c) {
                int i = dirtyItemsToFix.getI(c);
                sample_weight[i] = expo_sample[i] / closestDist[i];
            }
            nextSample.addAll(dirtyItemsToFix);
            dirtyItemsToFix.clear();
            int next_indx;
            while (true) {
                next_indx = nextSample.poll();
                if (!dirty[next_indx]) {
                    break;
                }
                sample_weight[next_indx] = expo_sample[next_indx] / closestDist[next_indx];
                dirty[next_indx] = false;
                nextSample.add(next_indx);
            }
            indices[j] = next_indx;
            if (j + 1 < k) {
                ParallelUtils.run(parallel, j, (from, to) -> {
                    for (int m = from; m < to; ++m) {
                        double dist = dm.dist(indices[m], indices[jj], X, accelCache);
                        gamma[m] = dist * dist;
                    }
                });
            }
        }
    }

    /**
     * Oversampling (k-means||-style) seeding.
     * <ol>
     * <li>Start from one random candidate.</li>
     * <li>For 5 rounds, add each point independently when
     * {@code 2k * w * D^2 / sum > U(0,1)}.</li>
     * <li>Weight each candidate by the total weight of the points closest to it.</li>
     * <li>Reduce the candidates to {@code k} seeds with {@link #kppSelection}.</li>
     * </ol>
     * If fewer than {@code k} candidates survive (e.g. nearly identical data),
     * fall back to k-means++ on the full data.
     */
    private static void kbbSelection(int[] indices, Random rand, DataSet d, int k, DistanceMetric dm, List<Double> accelCache, boolean parallel) {
        final int trials = 5;
        final int oversample = 2 * k;
        // assigned_too[i]: position in C2 of point i's nearest candidate
        int[] assigned_too = new int[d.size()];
        IntList C2 = new IntList(trials * oversample);
        C2.add(rand.nextInt(d.size()));
        Vec w = d.getDataWeights();
        double[] closestDist = new double[d.size()];
        Arrays.fill(closestDist, Double.POSITIVE_INFINITY);
        List<Vec> X = d.getDataVectors();
        double sqrdDistSum = distancesToFirstCandidate(C2.getI(0), X, w, closestDist, dm, accelCache, parallel);
        for (int t = 0; t < trials; ++t) {
            final int orig_size = C2.size();
            oversampleCandidates(C2, w, closestDist, sqrdDistSum, oversample, rand);
            sqrdDistSum = ParallelUtils.run(parallel, X.size(), (start, end) -> {
                double partial = 0.0;
                for (int i = start; i < end; ++i) {
                    if (closestDist[i] == 0.0) {
                        continue;
                    }
                    for (int c = orig_size; c < C2.size(); ++c) {
                        double newDist = dm.dist(C2.getI(c), i, X, accelCache);
                        newDist *= newDist;
                        if (newDist < closestDist[i]) {
                            closestDist[i] = newDist;
                            assigned_too[i] = c;
                        }
                    }
                    partial += closestDist[i] * w.get(i);
                }
                return partial;
            }, (z, u) -> z + u);
        }
        if (C2.size() < k) {
            // Too few candidates: kppSelection on them could never find k distinct seeds
            kppSelection(indices, rand, d, k, dm, accelCache, parallel);
            return;
        }
        SimpleDataSet sds = weightedCandidateSet(d, X, w, C2, assigned_too);
        kppSelection(indices, rand, sds, k, dm, dm.getAccelerationCache(sds.getDataVectors(), parallel), parallel);
        for (int i = 0; i < k; ++i) {
            indices[i] = C2.getI(indices[i]);
        }
    }

    /**
     * Same oversampling scheme as {@link #kbbSelection}, but with two differences.
     * Each round's nearest-new-candidate search goes through a {@code VPTreeMV}
     * bounded by the current distance. The final reduction uses {@link #kppSelectionTIA}.
     */
    private static void kbbSelectionTIA(int[] indices, Random rand, DataSet d, int k, DistanceMetric dm, List<Double> accelCache, boolean parallel) {
        final int trials = 5;
        final int oversample = 2 * k;
        // assigned_too[i]: position in C2 of point i's nearest candidate
        int[] assigned_too = new int[d.size()];
        IntList C2 = new IntList(trials * oversample);
        C2.add(rand.nextInt(d.size()));
        Vec w = d.getDataWeights();
        double[] closestDist = new double[d.size()];
        Arrays.fill(closestDist, Double.POSITIVE_INFINITY);
        List<Vec> X = d.getDataVectors();
        double sqrdDistSum = distancesToFirstCandidate(C2.getI(0), X, w, closestDist, dm, accelCache, parallel);
        for (int t = 0; t < trials; ++t) {
            final int orig_size = C2.size();
            oversampleCandidates(C2, w, closestDist, sqrdDistSum, oversample, rand);
            List<Integer> to_assign = C2.subList(orig_size, C2.size());
            if (to_assign.isEmpty()) {
                // No new candidates: distances and their sum are unchanged
                continue;
            }
            List<Vec> X_new_means = new ArrayList<Vec>(to_assign.size());
            for (int j : to_assign) {
                X_new_means.add(X.get(j));
            }
            VPTreeMV vp = new VPTreeMV(X_new_means, dm, parallel);
            sqrdDistSum = ParallelUtils.run(parallel, X.size(), (start, end) -> {
                double partial = 0.0;
                IntList neighbors = new IntList();
                DoubleList distances = new DoubleList();
                for (int i = start; i < end; ++i) {
                    if (closestDist[i] == 0.0) {
                        continue;
                    }
                    neighbors.clear();
                    distances.clear();
                    vp.search(X.get(i), 1, Math.sqrt(closestDist[i]), neighbors, distances);
                    if (!distances.isEmpty()) {
                        double newDist = distances.getD(0);
                        newDist *= newDist;
                        if (newDist < closestDist[i]) {
                            closestDist[i] = newDist;
                            assigned_too[i] = orig_size + neighbors.getI(0);
                        }
                    }
                    // Count every point, including those with no closer new candidate;
                    // otherwise the sum is underestimated and later rounds oversample.
                    partial += closestDist[i] * w.get(i);
                }
                return partial;
            }, (z, u) -> z + u);
        }
        if (C2.size() < k) {
            // Too few candidates: kppSelectionTIA on them could never find k distinct seeds
            kppSelectionTIA(indices, rand, d, k, dm, accelCache, parallel);
            return;
        }
        SimpleDataSet sds = weightedCandidateSet(d, X, w, C2, assigned_too);
        kppSelectionTIA(indices, rand, sds, k, dm, dm.getAccelerationCache(sds.getDataVectors(), parallel), parallel);
        for (int i = 0; i < k; ++i) {
            indices[i] = C2.getI(indices[i]);
        }
    }

    /**
     * Sets {@code closestDist[i]} to the squared distance from every point to row
     * {@code first}, and returns the weighted sum of those squared distances.
     */
    private static double distancesToFirstCandidate(int first, List<Vec> X, Vec w, double[] closestDist, DistanceMetric dm, List<Double> accelCache, boolean parallel) {
        return ParallelUtils.run(parallel, X.size(), (start, end) -> {
            double partial = 0.0;
            for (int i = start; i < end; ++i) {
                double newDist = dm.dist(first, i, X, accelCache);
                newDist *= newDist;
                if (newDist < closestDist[i]) {
                    closestDist[i] = newDist;
                }
                partial += closestDist[i] * w.get(i);
            }
            return partial;
        }, (z, u) -> z + u);
    }

    /** Appends each point to {@code C2} independently when {@code oversample * w * D^2 / sqrdDistSum > U(0,1)}. */
    private static void oversampleCandidates(IntList C2, Vec w, double[] closestDist, double sqrdDistSum, int oversample, Random rand) {
        for (int i = 0; i < closestDist.length; ++i) {
            if (w.get(i) * (double) oversample * closestDist[i] / sqrdDistSum > rand.nextDouble()) {
                C2.add(i);
            }
        }
    }

    /**
     * Builds a data set holding the candidate rows {@code C2}, in order. Each row's
     * weight is the total weight of the points assigned to it in {@code assigned_too}.
     */
    private static SimpleDataSet weightedCandidateSet(DataSet d, List<Vec> X, Vec w, IntList C2, int[] assigned_too) {
        DenseVector weights = new DenseVector(C2.size());
        for (int i = 0; i < X.size(); ++i) {
            weights.increment(assigned_too[i], w.get(i));
        }
        SimpleDataSet sds = new SimpleDataSet(d.getNumNumericalVars(), new CategoricalData[0]);
        for (int c = 0; c < C2.size(); ++c) {
            sds.add(new DataPoint(X.get(C2.getI(c))));
            sds.setWeight(c, weights.get(c));
        }
        return sds;
    }

    /**
     * Farthest-first traversal. The first seed is uniform at random. Each later seed is
     * the point whose (unsquared) distance to its nearest chosen seed is largest; ties
     * go to the lowest row index. If {@code k} exceeds the number of distinct points,
     * duplicate seeds may be returned.
     */
    private static void ffSelection(int[] indices, Random rand, DataSet d, int k, DistanceMetric dm, List<Double> accelCache, boolean parallel) {
        final int n = d.size();
        indices[0] = rand.nextInt(n);
        double[] closestDist = new double[n];
        Arrays.fill(closestDist, Double.POSITIVE_INFINITY);
        List<Vec> X = d.getDataVectors();
        for (int j = 1; j < k; ++j) {
            final int newMeanIndx = indices[j - 1];
            // Shared reduction state, guarded by the monitor of 'best'. Chunks compare their
            // own (dist, index) results instead of reading other chunks' closestDist entries,
            // which may not be written yet.
            final double[] best = {Double.NEGATIVE_INFINITY};
            final int[] bestIndx = {indices[0]};
            ParallelUtils.run(parallel, n, (start, end) -> {
                double maxDist = Double.NEGATIVE_INFINITY;
                int max = -1;
                for (int i = start; i < end; ++i) {
                    double newDist = dm.dist(newMeanIndx, i, X, accelCache);
                    closestDist[i] = Math.min(newDist, closestDist[i]);
                    if (closestDist[i] > maxDist) {
                        maxDist = closestDist[i];
                        max = i;
                    }
                }
                if (max >= 0) {
                    synchronized (best) {
                        if (maxDist > best[0] || (maxDist == best[0] && max < bestIndx[0])) {
                            best[0] = maxDist;
                            bestIndx[0] = max;
                        }
                    }
                }
            });
            synchronized (best) {
                indices[j] = bestIndx[0];
            }
        }
    }

    /**
     * Mean-quantile seeding. Points are ranked by distance to the data set's mean vector,
     * and seed {@code l} is the point at rank {@code floor(l * n / k)}. Uses no randomness.
     */
    private static void mqSelection(int[] indices, DataSet d, int k, DistanceMetric dm, List<Double> accelCache, boolean parallel) {
        final int n = d.size();
        double[] meanDist = new double[n];
        Vec newMean = MatrixStatistics.meanVector(d);
        List<Double> meanQI = dm.getQueryInfo(newMean);
        List<Vec> X = d.getDataVectors();
        ParallelUtils.run(parallel, n, (start, end) -> {
            for (int i = start; i < end; ++i) {
                meanDist[i] = dm.dist(i, newMean, meanQI, X, accelCache);
            }
        });
        IndexTable indxTbl = new IndexTable(meanDist);
        for (int l = 0; l < k; ++l) {
            // long arithmetic: l * n overflows int for large data sets
            indices[l] = indxTbl.index((int) ((long) l * n / k));
        }
    }

    /** The available seeding strategies. */
    public static enum SeedSelection {
        /** Distinct points chosen uniformly at random. */
        RANDOM,
        /** k-means++ style sampling proportional to weight times squared distance to the nearest seed. */
        KPP,
        /** {@link #KPP} with triangle-inequality pruning and exponential-race sampling. */
        KPP_TIA,
        /** Oversampled candidate selection (5 rounds of 2k), reduced to k seeds with {@link #KPP}. */
        KBB,
        /** {@link #KBB} using a vantage-point tree for candidate search and {@link #KPP_TIA} for the reduction. */
        KBB_TIA,
        /** Repeatedly picks the point farthest from all seeds chosen so far. */
        FARTHEST_FIRST,
        /** Evenly spaced quantiles of the points' distances to the mean vector. */
        MEAN_QUANTILES;

    }
}

