/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.clustering.ClusterFailureException;
import jsat.clustering.SeedSelectionMethods;
import jsat.clustering.kmeans.KMeans;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DenseSparseMetric;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDoubleArray;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * k-means clustering accelerated with triangle-inequality bounds.
 * <p>
 * For every data point the implementation keeps one upper bound on the
 * distance to its assigned centroid and one lower bound per centroid. A
 * point/centroid distance is only recomputed when those bounds cannot rule the
 * centroid out. The private {@code stepN...} helpers appear to follow the step
 * numbering of Elkan's accelerated k-means algorithm.
 * <p>
 * The distance metric must be subadditive (obey the triangle inequality); the
 * constructors reject metrics that are not.
 */
public class ElkanKMeans
extends KMeans {
    private static final long serialVersionUID = -1629432283103273051L;
    /**
     * Dense/sparse capable view of {@code dm} used by the current run, or
     * {@code null} when that code path is not in use. Re-derived at the start
     * of every {@code cluster} call.
     */
    private DenseSparseMetric dmds;
    /** Whether the dense/sparse distance path should be used when the metric supports it. */
    private boolean useDenseSparse = false;

    /**
     * Creates a new clusterer.
     *
     * @param dm the distance metric to use
     * @param rand the source of randomness
     * @param seedSelection the seed selection method for the initial means
     * @throws ClusterFailureException if {@code dm} is not subadditive
     */
    public ElkanKMeans(DistanceMetric dm, Random rand, SeedSelectionMethods.SeedSelection seedSelection) {
        super(dm, seedSelection, rand);
        if (!dm.isSubadditive()) {
            throw new ClusterFailureException("KMeans implementation requires the triangle inequality");
        }
    }

    /**
     * Creates a new clusterer using {@code DEFAULT_SEED_SELECTION}.
     *
     * @param dm the distance metric to use
     * @param rand the source of randomness
     */
    public ElkanKMeans(DistanceMetric dm, Random rand) {
        this(dm, rand, DEFAULT_SEED_SELECTION);
    }

    /**
     * Creates a new clusterer with a random source obtained from
     * {@code RandomUtil.getRandom()}.
     *
     * @param dm the distance metric to use
     */
    public ElkanKMeans(DistanceMetric dm) {
        this(dm, RandomUtil.getRandom());
    }

    /** Creates a new clusterer that uses {@link EuclideanDistance}. */
    public ElkanKMeans() {
        this(new EuclideanDistance());
    }

    /**
     * Copy constructor.
     *
     * @param toCopy the object to copy
     */
    public ElkanKMeans(ElkanKMeans toCopy) {
        super(toCopy);
        if (toCopy.dmds != null) {
            this.dmds = (DenseSparseMetric)toCopy.dmds.clone();
        }
        this.useDenseSparse = toCopy.useDenseSparse;
    }

    /**
     * Enables or disables the dense/sparse distance path. The setting only has
     * an effect when the distance metric is a {@link DenseSparseMetric}; it is
     * re-evaluated on every clustering run.
     *
     * @param useDenseSparse {@code true} to use the dense/sparse path when available
     */
    public void setUseDenseSparse(boolean useDenseSparse) {
        this.useDenseSparse = useDenseSparse;
    }

    /**
     * @return {@code true} if the dense/sparse distance path is requested
     */
    public boolean isUseDenseSparse() {
        return this.useDenseSparse;
    }

    /**
     * Runs the bounded k-means iterations.
     * <p>
     * NOTE(review): every exception raised inside this method - including the
     * {@link ClusterFailureException} thrown when there are fewer points than
     * clusters - is caught, logged at {@code SEVERE} and reported to the caller
     * as {@code Double.MAX_VALUE}. Confirm against callers before changing
     * this contract.
     *
     * @param dataSet the data to cluster
     * @param accelCache acceleration cache for the data vectors, or
     *        {@code null} to obtain one from the distance metric
     * @param k the number of clusters
     * @param means the centroids, updated in place; if its size differs from
     *        {@code k} it is cleared and re-seeded, and sparse entries are
     *        replaced by dense copies
     * @param assignment output array; entry {@code q} receives the index of the
     *        centroid that point {@code q} is assigned to
     * @param exactTotal when an error is returned, {@code true} recomputes it
     *        from exact distances while {@code false} sums the squared upper
     *        bounds
     * @param parallel whether work is handed to {@code ParallelUtils} in
     *        parallel mode
     * @param returnError whether to compute the total squared distance
     * @param dataPointWeights per-point weights, or {@code null} to use
     *        {@code dataSet.getDataWeights()}
     * @return the total squared distance when {@code returnError} is set,
     *         {@code 0.0} when it is not, or {@code Double.MAX_VALUE} if an
     *         exception occurred
     */
    @Override
    protected double cluster(DataSet dataSet, List<Double> accelCache, final int k, List<Vec> means, int[] assignment, boolean exactTotal, boolean parallel, boolean returnError, Vec dataPointWeights) {
        try {
            final int N = dataSet.size();
            final int D2 = dataSet.getNumNumericalVars();
            if (N < k) {
                throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
            }
            final Vec W = dataPointWeights == null ? dataSet.getDataWeights() : dataPointWeights;
            TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet);
            final List<Vec> X = dataSet.getDataVectors();
            final List<List<Double>> meanQIs = new ArrayList<List<Double>>(k);
            final List<Double> distAccelCache = accelCache == null ? this.dm.getAccelerationCache(X, parallel) : accelCache;
            if (means.size() != k) {
                means.clear();
                means.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, k, this.dm, distAccelCache, this.rand, this.seedSelection, parallel));
            }
            // The means are accumulated into, so they must be dense.
            for (int i = 0; i < means.size(); ++i) {
                if (means.get(i).isSparse()) {
                    means.set(i, new DenseVector(means.get(i)));
                }
            }

            // Re-derive the dense/sparse metric on every run so that a value
            // left over from an earlier run cannot outlive a later
            // setUseDenseSparse(false).
            this.dmds = this.useDenseSparse && this.dm instanceof DenseSparseMetric ? (DenseSparseMetric)this.dm : null;
            final double[] meanSummaryConsts = this.dmds != null ? new double[means.size()] : null;

            final double[][] lowerBound = new double[N][k];
            final double[] upperBound = new double[N];
            final double[][] centroidSelfDistances = new double[k][k];
            final double[] sC = new double[k];
            // Pass the summary constants here as well: the first loop iteration
            // below skips the recomputation, so this is the only place they get
            // filled in before the first round of distance calculations.
            this.calculateCentroidDistances(k, centroidSelfDistances, means, sC, meanSummaryConsts, parallel);

            final AtomicDoubleArray meanCount = new AtomicDoubleArray(k);
            final Vec[] oldMeans = new Vec[k];
            final Vec[] meanSums = new Vec[k];
            final boolean accelerated = this.dm.supportsAcceleration();
            for (int i = 0; i < k; ++i) {
                oldMeans[i] = means.get(i).clone();
                if (accelerated) {
                    meanQIs.add(this.dm.getQueryInfo(means.get(i)));
                } else {
                    // Placeholder so later lookups never see null.
                    meanQIs.add(Collections.<Double>emptyList());
                }
                meanSums[i] = new DenseVector(D2);
            }

            // Forces at least two passes regardless of whether a change was seen.
            int atLeast = 2;
            final AtomicBoolean changeOccurred = new AtomicBoolean(true);
            // r[q] marks that the upper bound of point q must be refreshed.
            final boolean[] r = new boolean[N];
            // Per-thread accumulators of changes to the mean sums.
            final ThreadLocal<Vec[]> localDeltas = ThreadLocal.withInitial(() -> {
                Vec[] fresh = new Vec[k];
                for (int i = 0; i < fresh.length; ++i) {
                    fresh[i] = new DenseVector(D2);
                }
                return fresh;
            });

            this.initialClusterSetUp(k, N, X, means, lowerBound, upperBound, centroidSelfDistances, assignment, meanCount, meanSums, distAccelCache, meanQIs, localDeltas, parallel, W);

            int iterLimit = this.MaxIterLimit;
            while ((changeOccurred.get() || atLeast > 0) && iterLimit-- >= 0) {
                --atLeast;
                changeOccurred.set(false);
                // Centroid distances for the first pass were computed above.
                if (iterLimit < this.MaxIterLimit - 1) {
                    this.calculateCentroidDistances(k, centroidSelfDistances, means, sC, meanSummaryConsts, parallel);
                }
                ParallelUtils.run(parallel, N, q -> {
                    // Upper bound within half the distance to the closest other
                    // centroid: the assignment cannot change.
                    if (upperBound[q] <= sC[assignment[q]]) {
                        return;
                    }
                    Vec v = X.get(q);
                    for (int c = 0; c < k; ++c) {
                        if (c != assignment[q]
                                && upperBound[q] > lowerBound[q][c]
                                && upperBound[q] > centroidSelfDistances[assignment[q]][c] * 0.5) {
                            this.step3aBoundsUpdate(X, r, q, v, means, assignment, upperBound, lowerBound, meanSummaryConsts, distAccelCache, meanQIs);
                            this.step3bUpdate(X, upperBound, q, lowerBound, c, centroidSelfDistances, assignment, v, means, localDeltas, meanCount, changeOccurred, meanSummaryConsts, distAccelCache, meanQIs, W);
                        }
                    }
                    this.step4UpdateCentroids(meanSums, localDeltas);
                });
                this.step5_6_distanceMovedBoundsUpdate(k, oldMeans, means, meanSums, meanCount, N, lowerBound, upperBound, assignment, r, meanQIs, parallel);
            }

            double totalDistance = 0.0;
            if (returnError) {
                this.nearestCentroidDist = this.saveCentroidDistance ? new double[N] : null;
                for (int i = 0; i < N; ++i) {
                    double dist = exactTotal
                            ? this.dm.dist(i, means.get(assignment[i]), meanQIs.get(assignment[i]), X, distAccelCache)
                            : upperBound[i];
                    totalDistance += Math.pow(dist, 2.0);
                    if (this.saveCentroidDistance) {
                        this.nearestCentroidDist[i] = dist;
                    }
                }
            }
            return totalDistance;
        }
        catch (Exception ex) {
            Logger.getLogger(ElkanKMeans.class.getName()).log(Level.SEVERE, null, ex);
            return Double.MAX_VALUE;
        }
    }

    /**
     * Assigns every point to its closest centroid and initialises the bounds,
     * the weighted cluster counts and the mean sums.
     * <p>
     * While scanning the centroids for a point, a centroid {@code z} is skipped
     * once a centroid {@code i} at distance {@code d} has been found with
     * {@code centroidSelfDistances[i][z] >= 2 * d}; the lower bound of a
     * skipped centroid is left at its current value.
     */
    private void initialClusterSetUp(int k, int N, List<Vec> dataSet, List<Vec> means, double[][] lowerBound, double[] upperBound, double[][] centroidSelfDistances, int[] assignment, AtomicDoubleArray meanCount, Vec[] meanSums, List<Double> distAccelCache, List<List<Double>> meanQIs, ThreadLocal<Vec[]> localDeltas, boolean parallel, Vec W) {
        ParallelUtils.run(parallel, N, (from, to) -> {
            Vec[] deltas = localDeltas.get();
            boolean[] skip = new boolean[k];
            for (int q = from; q < to; ++q) {
                Vec v = dataSet.get(q);
                double minDistance = Double.MAX_VALUE;
                int index = -1;
                Arrays.fill(skip, false);
                for (int i = 0; i < k; ++i) {
                    if (skip[i]) {
                        continue;
                    }
                    double d = this.dm.dist(q, means.get(i), meanQIs.get(i), dataSet, distAccelCache);
                    lowerBound[q][i] = d;
                    if (d < minDistance) {
                        minDistance = d;
                        upperBound[q] = d;
                        index = i;
                        // Centroids at least 2*d away from centroid i cannot be closer than d.
                        for (int z = i + 1; z < k; ++z) {
                            if (centroidSelfDistances[i][z] >= 2.0 * d) {
                                skip[z] = true;
                            }
                        }
                    }
                }
                assignment[q] = index;
                double weight = W.get(q);
                meanCount.addAndGet(index, weight);
                deltas[index].mutableAdd(weight, v);
            }
            // Fold this chunk's accumulators into the shared sums.
            for (int i = 0; i < deltas.length; ++i) {
                synchronized (meanSums[i]) {
                    meanSums[i].mutableAdd(deltas[i]);
                }
                deltas[i].zeroOut();
            }
        });
    }

    /**
     * Adds the calling thread's non-empty delta vectors to the shared mean
     * sums (synchronizing on each sum while adding) and then zeroes them.
     */
    private void step4UpdateCentroids(Vec[] meanSums, ThreadLocal<Vec[]> localDeltas) {
        Vec[] deltas = localDeltas.get();
        for (int i = 0; i < deltas.length; ++i) {
            if (deltas[i].nnz() == 0) {
                continue;
            }
            synchronized (meanSums[i]) {
                meanSums[i].mutableAdd(deltas[i]);
            }
            deltas[i].zeroOut();
        }
    }

    /**
     * Recomputes every mean from its sum and weighted count, then loosens the
     * bounds by the distance each mean moved: lower bounds are reduced (floored
     * at zero), upper bounds are increased, and every point is flagged in
     * {@code r} so its upper bound is refreshed before it is next used.
     * A cluster whose weighted count is at most {@code 1e-14} has its mean
     * zeroed instead of divided.
     */
    private void step5_6_distanceMovedBoundsUpdate(int k, Vec[] oldMeans, List<Vec> means, Vec[] meanSums, AtomicDoubleArray meanCount, int N, double[][] lowerBound, double[] upperBound, int[] assignment, boolean[] r, List<List<Double>> meanQIs, boolean parallel) {
        double[] distancesMoved = new double[k];
        ParallelUtils.run(parallel, k, i -> {
            Vec mean = means.get(i);
            mean.copyTo(oldMeans[i]);
            meanSums[i].copyTo(mean);
            double count = meanCount.get(i);
            if (count <= 1.0E-14) {
                mean.zeroOut();
            } else {
                mean.mutableDivide(count);
            }
            distancesMoved[i] = this.dm.dist(oldMeans[i], mean);
            if (this.dm.supportsAcceleration()) {
                meanQIs.set(i, this.dm.getQueryInfo(mean));
            }
            for (int q = 0; q < N; ++q) {
                lowerBound[q][i] = Math.max(lowerBound[q][i] - distancesMoved[i], 0.0);
            }
        });
        ParallelUtils.run(parallel, N, (start, end) -> {
            for (int q = start; q < end; ++q) {
                upperBound[q] += distancesMoved[assignment[q]];
                r[q] = true;
            }
        });
    }

    /**
     * If point {@code q} is flagged in {@code r}, clears the flag and replaces
     * both its upper bound and the lower bound for its assigned centroid with
     * the freshly computed distance to that centroid.
     */
    private void step3aBoundsUpdate(List<Vec> X, boolean[] r, int q, Vec v, List<Vec> means, int[] assignment, double[] upperBound, double[][] lowerBound, double[] meanSummaryConsts, List<Double> distAccelCache, List<List<Double>> meanQIs) {
        if (!r[q]) {
            return;
        }
        r[q] = false;
        int meanIndx = assignment[q];
        double d = this.dmds == null
                ? this.dm.dist(q, means.get(meanIndx), meanQIs.get(meanIndx), X, distAccelCache)
                : this.dmds.dist(meanSummaryConsts[meanIndx], means.get(meanIndx), v);
        lowerBound[q][meanIndx] = d;
        upperBound[q] = d;
    }

    /**
     * Computes the distance from point {@code q} to centroid {@code c} when the
     * bounds cannot exclude it, stores it as the lower bound, and reassigns the
     * point to {@code c} if that distance is smaller than the current upper
     * bound. A reassignment moves the point's weight between the cluster
     * counts and the calling thread's delta vectors and sets
     * {@code changeOccurred}.
     */
    private void step3bUpdate(List<Vec> X, double[] upperBound, int q, double[][] lowerBound, int c, double[][] centroidSelfDistances, int[] assignment, Vec v, List<Vec> means, ThreadLocal<Vec[]> localDeltas, AtomicDoubleArray meanCount, AtomicBoolean changeOccurred, double[] meanSummaryConsts, List<Double> distAccelCache, List<List<Double>> meanQIs, Vec W) {
        boolean mayBeCloser = upperBound[q] > lowerBound[q][c] || upperBound[q] > centroidSelfDistances[assignment[q]][c] / 2.0;
        if (!mayBeCloser) {
            return;
        }
        double d = this.dmds == null
                ? this.dm.dist(q, means.get(c), meanQIs.get(c), X, distAccelCache)
                : this.dmds.dist(meanSummaryConsts[c], means.get(c), v);
        lowerBound[q][c] = d;
        if (d < upperBound[q]) {
            Vec[] deltas = localDeltas.get();
            double weight = W.get(q);
            int previous = assignment[q];
            deltas[previous].mutableSubtract(weight, v);
            meanCount.addAndGet(previous, -weight);
            deltas[c].mutableAdd(weight, v);
            meanCount.addAndGet(c, weight);
            assignment[q] = c;
            upperBound[q] = d;
            changeOccurred.set(true);
        }
    }

    /**
     * Fills {@code centroidSelfDistances} with the pairwise distances between
     * the means and sets {@code sC[i]} to half the distance from mean
     * {@code i} to its closest other mean. When {@code meanSummaryConsts} is
     * non-null it is also filled with {@code dmds.getVectorConstant(mean)} for
     * every mean.
     */
    private void calculateCentroidDistances(int k, double[][] centroidSelfDistances, List<Vec> means, double[] sC, double[] meanSummaryConsts, boolean parallel) {
        List<Double> meanAccelCache = this.dm.supportsAcceleration() ? this.dm.getAccelerationCache(means) : null;
        ParallelUtils.run(parallel, k, i -> {
            for (int z = i + 1; z < k; ++z) {
                double d = this.dm.dist(i, z, means, meanAccelCache);
                centroidSelfDistances[i][z] = d;
                centroidSelfDistances[z][i] = d;
            }
            if (meanSummaryConsts != null) {
                meanSummaryConsts[i] = this.dmds.getVectorConstant(means.get(i));
            }
        });
        for (int i = 0; i < k; ++i) {
            double closest = Double.MAX_VALUE;
            for (int z = 0; z < k; ++z) {
                if (z != i) {
                    closest = Math.min(closest, centroidSelfDistances[i][z]);
                }
            }
            sC[i] = closest / 2.0;
        }
    }

    /**
     * @return a copy of this object made with the copy constructor
     */
    @Override
    public ElkanKMeans clone() {
        return new ElkanKMeans(this);
    }
}

