/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.DoubleStream;
import jsat.DataSet;
import jsat.clustering.PAM;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.concurrent.AtomicDouble;
import jsat.utils.concurrent.AtomicDoubleArray;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

public class TRIKMEDS
extends PAM {
    public TRIKMEDS(DistanceMetric dm, Random rand, SeedSelectionMethods.SeedSelection seedSelection) {
        super(dm, rand, seedSelection);
    }

    public TRIKMEDS(DistanceMetric dm, Random rand) {
        super(dm, rand);
    }

    public TRIKMEDS(DistanceMetric dm) {
        super(dm);
    }

    public TRIKMEDS() {
    }

    @Override
    public void setDistanceMetric(DistanceMetric dm) {
        if (!dm.isValidMetric()) {
            throw new IllegalArgumentException("TRIKMEDS requires a valid distance metric, but " + dm.toString() + " does not obey all distance metric properties");
        }
        super.setDistanceMetric(dm);
    }

    @Override
    protected double cluster(DataSet data, boolean doInit, int[] medioids, int[] assignments, List<Double> cacheAccel, boolean parallel) {
        int k2;
        List<Double> accel;
        LongAdder changes = new LongAdder();
        Arrays.fill(assignments, -1);
        List<Vec> X = data.getDataVectors();
        if (doInit) {
            TrainableDistanceMetric.trainIfNeeded(this.dm, data);
            accel = this.dm.getAccelerationCache(X);
            SeedSelectionMethods.selectIntialPoints(data, medioids, this.dm, accel, this.rand, this.seedSelection);
        } else {
            accel = cacheAccel;
        }
        int N = data.size();
        int K = medioids.length;
        AtomicIntegerArray m = new AtomicIntegerArray(K);
        int[] c = medioids;
        int[] a = assignments;
        double[] d = new double[N];
        double[] d_tilde = new double[N];
        AtomicDoubleArray v = new AtomicDoubleArray(K);
        double[][] lc = new double[N][K];
        AtomicDoubleArray ls = new AtomicDoubleArray(N);
        double[] p = new double[K];
        AtomicDoubleArray s = new AtomicDoubleArray(K);
        ArrayList ownedBy = new ArrayList(K);
        for (int i2 = 0; i2 < K; ++i2) {
            ownedBy.add(new ConcurrentSkipListSet());
        }
        AtomicDoubleArray delta_n_in = new AtomicDoubleArray(K);
        AtomicDoubleArray delta_n_out = new AtomicDoubleArray(K);
        AtomicDoubleArray delta_s_in = new AtomicDoubleArray(K);
        AtomicDoubleArray delta_s_out = new AtomicDoubleArray(K);
        for (k2 = 0; k2 < K; ++k2) {
            m.set(k2, c[k2]);
        }
        ParallelUtils.run(parallel, N, (start, end) -> {
            for (int i = start; i < end; ++i) {
                double a_min_val = Double.POSITIVE_INFINITY;
                int a_min_k = 0;
                for (int k = 0; k < K; ++k) {
                    lc[i][k] = this.dm.dist(i, m.get(k), (List<? extends Vec>)X, accel);
                    if (!(lc[i][k] <= a_min_val)) continue;
                    a_min_val = lc[i][k];
                    a_min_k = k;
                }
                a[i] = a_min_k;
                d[i] = a_min_val;
                v.getAndAdd(a[i], 1.0);
                ((Set)ownedBy.get(a_min_k)).add(i);
                s.addAndGet(a[i], d[i]);
                ls.set(i, 0.0);
            }
        });
        for (k2 = 0; k2 < K; ++k2) {
            ls.set(m.get(k2), s.get(k2));
        }
        int iter = 0;
        do {
            changes.reset();
            boolean[] medioid_changed = new boolean[K];
            Arrays.fill(medioid_changed, false);
            ParallelUtils.run(parallel, N, i -> {
                for (int k = 0; k < K; ++k) {
                    if (!(ls.get(i) < s.get(k))) continue;
                    double ls_i_new = 0.0;
                    Iterator iterator = ((Set)ownedBy.get(k)).iterator();
                    while (iterator.hasNext()) {
                        int j = (Integer)iterator.next();
                        d_tilde[j] = this.dm.dist(i, j, (List<? extends Vec>)X, accel);
                        ls_i_new += d_tilde[j];
                    }
                    ls.set(i, ls_i_new);
                    if (ls_i_new < s.get(k)) {
                        iterator = s;
                        synchronized (iterator) {
                            if (ls_i_new < s.get(k)) {
                                s.set(k, ls_i_new);
                                m.set(k, i);
                                medioid_changed[k] = true;
                                Iterator j = ((Set)ownedBy.get(k)).iterator();
                                while (j.hasNext()) {
                                    int j2 = (Integer)j.next();
                                    d[j2] = d_tilde[j2];
                                }
                            }
                        }
                    }
                    iterator = ((Set)ownedBy.get(k)).iterator();
                    while (iterator.hasNext()) {
                        int j = (Integer)iterator.next();
                        ls.accumulateAndGet(j, d[j] * v.get(k), (ls_j, d_jXv_k) -> Math.max(ls_j, Math.abs(d_jXv_k - ls_j)));
                    }
                }
            });
            ParallelUtils.run(parallel, K, k -> {
                if (medioid_changed[k]) {
                    p[k] = this.dm.dist(c[k], m.get(k), (List<? extends Vec>)X, accel);
                    c[k] = m.get(k);
                }
                delta_n_in.set(k, 0.0);
                delta_n_out.set(k, 0.0);
                delta_s_in.set(k, 0.0);
                delta_s_out.set(k, 0.0);
            });
            ParallelUtils.run(parallel, N, i -> {
                for (int k = 0; k < K; ++k) {
                    double[] dArray = lc[i];
                    int n = k;
                    dArray[n] = dArray[n] - p[k];
                }
                lc[i][a[i]] = d[i];
                int a_old = a[i];
                double d_old = d[i];
                for (int k = 0; k < K; ++k) {
                    if (!(lc[i][k] < d[i])) continue;
                    lc[i][k] = this.dm.dist(i, c[k], (List<? extends Vec>)X, accel);
                    if (!(lc[i][k] < d[i])) continue;
                    a[i] = k;
                    d[i] = lc[i][k];
                }
                if (a_old != a[i]) {
                    v.getAndDecrement(a_old);
                    v.getAndIncrement(a[i]);
                    changes.increment();
                    ((Set)ownedBy.get(a_old)).remove(i);
                    ((Set)ownedBy.get(a[i])).add(i);
                    ls.set(i, 0.0);
                    delta_n_in.getAndIncrement(a[i]);
                    delta_n_out.getAndIncrement(a_old);
                    delta_s_in.getAndAdd(a[i], d[i]);
                    delta_s_in.getAndAdd(a_old, d_old);
                }
            });
            double[] J_abs_s = new double[K];
            double[] J_net_s = new double[K];
            double[] J_abs_n = new double[K];
            double[] J_net_n = new double[K];
            for (int k3 = 0; k3 < K; ++k3) {
                J_abs_s[k3] = delta_s_in.get(k3) + delta_s_out.get(k3);
                J_net_s[k3] = delta_s_in.get(k3) - delta_s_out.get(k3);
                J_abs_n[k3] = delta_n_in.get(k3) + delta_n_out.get(k3);
                J_net_n[k3] = delta_n_in.get(k3) - delta_n_out.get(k3);
            }
            ParallelUtils.run(parallel, N, (start, end) -> {
                for (int i = start; i < end; ++i) {
                    double ls_i_delta = 0.0;
                    for (int k = 0; k < K; ++k) {
                        ls_i_delta -= Math.min(J_abs_s[k] - J_net_n[k] * d[i], J_abs_n[k] * d[i] - J_net_s[k]);
                    }
                    ls.getAndAdd(i, ls_i_delta);
                }
            });
        } while (changes.sum() > 0L && iter++ < this.iterLimit);
        return ParallelUtils.streamP(DoubleStream.of(d), parallel).map(x -> x * x).sum();
    }

    public static int medoid(boolean parallel, List<? extends Vec> X, DistanceMetric dm) {
        IntList order = new IntList(X.size());
        ListUtils.addRange(order, 0, X.size(), 1);
        List<Double> accel = dm.getAccelerationCache(X, parallel);
        return TRIKMEDS.medoid(parallel, order, X, dm, accel);
    }

    public static int medoid(boolean parallel, Collection<Integer> indecies, List<? extends Vec> X, DistanceMetric dm, List<Double> accel) {
        int N = X.size();
        AtomicDoubleArray l = new AtomicDoubleArray(N);
        AtomicDouble e_cl = new AtomicDouble(Double.POSITIVE_INFINITY);
        IntList rand_order = new IntList(indecies);
        Collections.shuffle(rand_order, RandomUtil.getRandom());
        ThreadLocal<double[]> d_local = ThreadLocal.withInitial(() -> new double[N]);
        ParallelUtils.streamP(rand_order.streamInts(), parallel).forEach(i -> {
            double[] d = (double[])d_local.get();
            double d_avg = 0.0;
            if (l.get(i) < e_cl.get()) {
                Iterator iterator = indecies.iterator();
                while (iterator.hasNext()) {
                    int j = (Integer)iterator.next();
                    d[j] = dm.dist(i, j, X, accel);
                    d_avg += d[j];
                }
                double l_i = d_avg /= (double)(indecies.size() - 1);
                l.set(i, l_i);
                if (l_i < e_cl.get()) {
                    e_cl.getAndUpdate(val -> Math.min(val, l_i));
                }
                Iterator iterator2 = indecies.iterator();
                while (iterator2.hasNext()) {
                    int j = (Integer)iterator2.next();
                    l.getAndUpdate(j, l_j -> Math.max(l_j, Math.abs(l_i - d[j])));
                }
            }
        });
        for (int i2 : indecies) {
            if (l.get(i2) != e_cl.get()) continue;
            return i2;
        }
        return -1;
    }
}

