/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.concurrent.atomic.LongAdder;
import jsat.DataSet;
import jsat.clustering.KClusterer;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.math.OnLineStatistics;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * Partitioning Around Medoids (k-medoids) clusterer.
 *
 * <p>Each cluster is represented by one of the data points (its medoid). The
 * algorithm alternates between assigning every point to its closest medoid and
 * re-selecting, for every cluster, the member that minimises the sum of squared
 * distances to the other members of that cluster.
 *
 * <p>Defaults visible in this class: {@link EuclideanDistance} as the metric,
 * {@code SeedSelection.KPP} for seeding, an iteration limit of 100, and medoids
 * retained after clustering.
 */
public class PAM implements KClusterer {
    private static final long serialVersionUID = 4787649180692115514L;

    /** Distance metric used for all point-to-point comparisons. */
    protected DistanceMetric dm;
    /** Source of randomness handed to the seed selection routine. */
    protected Random rand;
    /** Strategy used to pick the initial medoids. */
    protected SeedSelectionMethods.SeedSelection seedSelection;
    /** Upper bound on the number of assign/update passes per clustering run. */
    protected int iterLimit = 100;
    /** Medoid indices from the last fixed-k run, or {@code null} if not stored. */
    protected int[] medoids;
    /** Whether {@link #medoids} is kept after a fixed-k clustering run. */
    protected boolean storeMedoids = true;

    /**
     * Creates a PAM clusterer.
     *
     * @param dm the distance metric to use
     * @param rand the source of randomness for seed selection
     * @param seedSelection the seed selection strategy
     */
    public PAM(DistanceMetric dm, Random rand, SeedSelectionMethods.SeedSelection seedSelection) {
        this.dm = dm;
        this.rand = rand;
        this.seedSelection = seedSelection;
    }

    /**
     * Creates a PAM clusterer that uses {@code SeedSelection.KPP} seeding.
     *
     * @param dm the distance metric to use
     * @param rand the source of randomness for seed selection
     */
    public PAM(DistanceMetric dm, Random rand) {
        this(dm, rand, SeedSelectionMethods.SeedSelection.KPP);
    }

    /**
     * Creates a PAM clusterer with a {@link Random} obtained from
     * {@link RandomUtil#getRandom()}.
     *
     * @param dm the distance metric to use
     */
    public PAM(DistanceMetric dm) {
        this(dm, RandomUtil.getRandom());
    }

    /** Creates a PAM clusterer using {@link EuclideanDistance}. */
    public PAM() {
        this(new EuclideanDistance());
    }

    /**
     * Copy constructor. The metric is cloned, the medoid array (if any) is
     * copied, and a fresh {@link Random} is obtained rather than shared.
     *
     * @param toCopy the instance to copy
     */
    public PAM(PAM toCopy) {
        this.dm = toCopy.dm.clone();
        this.rand = RandomUtil.getRandom();
        this.seedSelection = toCopy.seedSelection;
        if (toCopy.medoids != null) {
            this.medoids = Arrays.copyOf(toCopy.medoids, toCopy.medoids.length);
        }
        this.storeMedoids = toCopy.storeMedoids;
        this.iterLimit = toCopy.iterLimit;
    }

    /**
     * Sets the maximum number of assign/update passes. At least one pass is
     * always performed regardless of this value.
     *
     * @param iterLimit the iteration limit
     */
    public void setMaxIterations(int iterLimit) {
        this.iterLimit = iterLimit;
    }

    /**
     * @return the maximum number of assign/update passes
     */
    public int getMaxIterations() {
        return this.iterLimit;
    }

    /**
     * @param dm the distance metric to use
     */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /**
     * @return the distance metric in use
     */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * Controls whether the medoid indices are kept after a fixed-k clustering
     * run.
     *
     * @param storeMedoids {@code true} to keep them, {@code false} to discard
     */
    public void setStoreMedoids(boolean storeMedoids) {
        this.storeMedoids = storeMedoids;
    }

    /**
     * Returns the internal medoid index array (not a copy).
     *
     * @return the stored medoid indices, or {@code null} if none are stored
     */
    public int[] getMedoids() {
        return this.medoids;
    }

    /**
     * @param seedSelection the seed selection strategy
     */
    public void setSeedSelection(SeedSelectionMethods.SeedSelection seedSelection) {
        this.seedSelection = seedSelection;
    }

    /**
     * @return the seed selection strategy
     */
    public SeedSelectionMethods.SeedSelection getSeedSelection() {
        return this.seedSelection;
    }

    /**
     * Runs the assign/update loop for a fixed number of medoids.
     *
     * @param data the data set to cluster
     * @param doInit if {@code true}, the metric is passed to
     *        {@code TrainableDistanceMetric.trainIfNeeded}, an acceleration
     *        cache is built and initial medoids are selected into
     *        {@code medioids}; if {@code false}, {@code medioids} is used as
     *        given together with {@code cacheAccel}
     * @param medioids medoid indices; its length is the number of clusters and
     *        it is overwritten with the final medoids
     * @param assignments output array, one entry per data point, overwritten
     *        with the cluster id of each point
     * @param cacheAccel acceleration cache to use when {@code doInit} is
     *        {@code false}; ignored otherwise
     * @param parallel whether the work should be run in parallel
     * @return the sum of squared point-to-medoid distances measured during the
     *         last assignment pass (i.e. against the medoids in effect before
     *         the final update step)
     */
    protected double cluster(DataSet data, boolean doInit, int[] medioids, int[] assignments, List<Double> cacheAccel, boolean parallel) {
        DoubleAdder totalDistance = new DoubleAdder();
        LongAdder changes = new LongAdder();
        Arrays.fill(assignments, -1);
        int[] bestMedCand = new int[medioids.length];
        double[] bestMedCandDist = new double[medioids.length];
        List<Vec> X = data.getDataVectors();

        final List<Double> accel;
        if (doInit) {
            TrainableDistanceMetric.trainIfNeeded(this.dm, data);
            accel = this.dm.getAccelerationCache(X);
            SeedSelectionMethods.selectIntialPoints(data, medioids, this.dm, accel, this.rand, this.seedSelection);
        } else {
            accel = cacheAccel;
        }

        int iter = 0;
        do {
            changes.reset();
            totalDistance.reset();

            // Assignment step: attach every point to its nearest medoid.
            // Ties keep the lowest cluster id because the comparison is strict.
            ParallelUtils.run(parallel, data.size(), (start, end) -> {
                for (int i = start; i < end; ++i) {
                    int assignment = 0;
                    double minDist = this.dm.dist(medioids[0], i, X, accel);
                    for (int k = 1; k < medioids.length; ++k) {
                        double dist = this.dm.dist(medioids[k], i, X, accel);
                        if (dist < minDist) {
                            minDist = dist;
                            assignment = k;
                        }
                    }
                    if (assignments[i] != assignment) {
                        changes.increment();
                        assignments[i] = assignment;
                    }
                    totalDistance.add(minDist * minDist);
                }
            });

            // Update step: for each cluster pick the member with the smallest
            // sum of squared distances to the other members of that cluster.
            Arrays.fill(bestMedCandDist, Double.MAX_VALUE);
            for (int i = 0; i < data.size(); ++i) {
                final int candidate = i;
                final int clusterID = assignments[i];
                double candidateCost = ParallelUtils.range(data.size(), parallel)
                        .filter(j -> j != candidate && assignments[j] == clusterID)
                        .mapToDouble(j -> Math.pow(this.dm.dist(candidate, j, X, accel), 2.0))
                        .sum();
                if (candidateCost < bestMedCandDist[clusterID]) {
                    bestMedCand[clusterID] = i;
                    bestMedCandDist[clusterID] = candidateCost;
                }
            }
            System.arraycopy(bestMedCand, 0, medioids, 0, medioids.length);
            // Pre-increment so that at most iterLimit passes are executed; the
            // post-increment form allowed iterLimit + 1 passes.
        } while (changes.sum() > 0L && ++iter < this.iterLimit);
        return totalDistance.sum();
    }

    /**
     * Clusters the data, choosing the number of clusters automatically by
     * searching k from 2 up to {@code (int) sqrt(n / 2)} (integer division),
     * where n is the data set size.
     *
     * @param dataSet the data set to cluster
     * @param parallel whether the work should be run in parallel
     * @param designations array to store assignments in, or {@code null} to
     *        allocate a new one
     * @return the array of cluster assignments
     */
    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        return this.cluster(dataSet, 2, (int) Math.sqrt(dataSet.size() / 2), parallel, designations);
    }

    /**
     * Clusters the data into exactly {@code clusters} clusters. The resulting
     * medoids are retained in this object only if medoid storage is enabled.
     *
     * @param dataSet the data set to cluster
     * @param clusters the number of clusters
     * @param parallel whether the work should be run in parallel
     * @param designations array to store assignments in, or {@code null} to
     *        allocate a new one
     * @return the array of cluster assignments
     */
    @Override
    public int[] cluster(DataSet dataSet, int clusters, boolean parallel, int[] designations) {
        if (designations == null) {
            designations = new int[dataSet.size()];
        }
        this.medoids = new int[clusters];
        this.cluster(dataSet, true, this.medoids, designations, null, parallel);
        if (!this.storeMedoids) {
            this.medoids = null;
        }
        return designations;
    }

    /**
     * @return a copy of this clusterer created via the copy constructor
     */
    @Override
    public PAM clone() {
        return new PAM(this);
    }

    /**
     * Clusters the data, choosing k from {@code [lowK, highK]}.
     *
     * <p>Every k in the range is clustered once and its total squared distance
     * recorded. The k selected is the upper k of the consecutive pair with the
     * largest absolute change in total distance, unless that change is below
     * {@code mean + 2 * standardDeviation} of all such changes, in which case
     * {@code lowK} is used. The data is then clustered again with the chosen k.
     *
     * @param dataSet the data set to cluster
     * @param lowK the smallest number of clusters to try
     * @param highK the largest number of clusters to try
     * @param parallel whether the work should be run in parallel
     * @param designations array to store assignments in, or {@code null} to
     *        allocate a new one
     * @return the array of cluster assignments for the chosen k
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, boolean parallel, int[] designations) {
        if (designations == null) {
            designations = new int[dataSet.size()];
        }
        double[] totDistances = new double[highK - lowK + 1];
        for (int k = lowK; k <= highK; ++k) {
            totDistances[k - lowK] = this.cluster(dataSet, true, new int[k], designations, null, parallel);
        }

        OnLineStatistics stats = new OnLineStatistics();
        double maxChange = Double.MIN_VALUE;
        int maxChangeK = lowK;
        for (int i = 1; i < totDistances.length; ++i) {
            double change = Math.abs(totDistances[i] - totDistances[i - 1]);
            stats.add(change);
            if (change > maxChange) {
                maxChange = change;
                maxChangeK = i + lowK;
            }
        }
        // Fall back to the smallest k when the largest change does not stand
        // out from the rest.
        if (maxChange < stats.getStandardDeviation() * 2.0 + stats.getMean()) {
            maxChangeK = lowK;
        }
        return this.cluster(dataSet, maxChangeK, parallel, designations);
    }

    /**
     * Finds the medoid of an entire list of vectors.
     *
     * @param parallel whether the work should be run in parallel
     * @param X the vectors
     * @param dm the distance metric to use
     * @return the index in {@code X} of the medoid, or {@code -1} if {@code X}
     *         is empty
     */
    public static int medoid(boolean parallel, List<? extends Vec> X, DistanceMetric dm) {
        IntList order = new IntList(X.size());
        ListUtils.addRange(order, 0, X.size(), 1);
        List<Double> accel = dm.getAccelerationCache(X, parallel);
        return PAM.medoid(parallel, order, X, dm, accel);
    }

    /**
     * Finds the medoid of the subset of {@code X} identified by
     * {@code indecies}: the member whose summed distance to all other members
     * of the subset is smallest. The first such member in iteration order wins
     * ties.
     *
     * @param parallel whether the per-candidate distance sum should be run in
     *        parallel
     * @param indecies the indices into {@code X} that make up the subset
     * @param X the vectors
     * @param dm the distance metric to use
     * @param accel the acceleration cache for {@code X}
     * @return the index in {@code X} of the medoid, or {@code -1} if
     *         {@code indecies} is empty
     */
    public static int medoid(boolean parallel, Collection<Integer> indecies, List<? extends Vec> X, DistanceMetric dm, List<Double> accel) {
        // Snapshot the subset so the inner sum runs over the actual member
        // indices. Previously the sum ran over positions 0..size-1, which is
        // only correct when the subset happens to be exactly {0, ..., n-1}.
        final int[] members = indecies.stream().mapToInt(Integer::intValue).toArray();

        double bestDist = Double.POSITIVE_INFINITY;
        int bestIndex = -1;
        for (int candidate : members) {
            double candidateCost = ParallelUtils.range(members.length, parallel)
                    .map(pos -> members[pos])
                    .filter(j -> j != candidate)
                    .mapToDouble(j -> dm.dist(candidate, j, X, accel))
                    .sum();
            if (candidateCost < bestDist) {
                bestIndex = candidate;
                bestDist = candidateCost;
            }
        }
        return bestIndex;
    }
}

