/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.PAM;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.DoubleList;
import jsat.utils.random.RandomUtil;

/**
 * Sampling-based variant of {@link PAM}: instead of running PAM on the whole
 * data set, it repeatedly draws a random sample of points, runs PAM on that
 * sample, scores the resulting medoids against the full data set, and keeps
 * the best-scoring medoids.
 *
 * <p>The sample size is either fixed by the caller or, in automatic mode,
 * computed as {@code 40 + 2 * k} where {@code k} is the number of medoids
 * requested.
 */
public class CLARA extends PAM {
    private static final long serialVersionUID = 174392533688953706L;

    /** Explicit number of points per sample; {@code -1} while automatic sizing is active. */
    private int sampleSize;
    /** Number of independent samples drawn (and PAM runs performed) per clustering. */
    private int sampleCount;
    /** When {@code true} the sample size is {@code 40 + 2 * k} and {@link #sampleSize} is ignored. */
    private boolean autoSampleSize;

    /**
     * Creates a CLARA clusterer with an explicit sample size.
     *
     * @param sampleSize    number of points drawn for each sample
     * @param sampleCount   number of samples to draw
     * @param dm            distance metric to use
     * @param rand          source of randomness
     * @param seedSelection seed selection method used for the initial medoids
     */
    public CLARA(int sampleSize, int sampleCount, DistanceMetric dm, Random rand, SeedSelectionMethods.SeedSelection seedSelection) {
        super(dm, rand, seedSelection);
        this.sampleSize = sampleSize;
        this.sampleCount = sampleCount;
        this.autoSampleSize = false;
    }

    /**
     * Creates a CLARA clusterer whose sample size is chosen automatically.
     *
     * @param sampleCount   number of samples to draw
     * @param dm            distance metric to use
     * @param rand          source of randomness
     * @param seedSelection seed selection method used for the initial medoids
     */
    public CLARA(int sampleCount, DistanceMetric dm, Random rand, SeedSelectionMethods.SeedSelection seedSelection) {
        super(dm, rand, seedSelection);
        this.sampleSize = -1;
        this.sampleCount = sampleCount;
        this.autoSampleSize = true;
    }

    /**
     * Creates a CLARA clusterer with automatic sample size and 5 samples.
     *
     * @param dm            distance metric to use
     * @param rand          source of randomness
     * @param seedSelection seed selection method used for the initial medoids
     */
    public CLARA(DistanceMetric dm, Random rand, SeedSelectionMethods.SeedSelection seedSelection) {
        this(5, dm, rand, seedSelection);
    }

    /**
     * Creates a CLARA clusterer with automatic sample size, 5 samples and
     * {@link SeedSelectionMethods.SeedSelection#KPP} seed selection.
     *
     * @param dm   distance metric to use
     * @param rand source of randomness
     */
    public CLARA(DistanceMetric dm, Random rand) {
        this(dm, rand, SeedSelectionMethods.SeedSelection.KPP);
    }

    /**
     * Creates a CLARA clusterer using the random source supplied by
     * {@link RandomUtil#getRandom()}.
     *
     * @param dm distance metric to use
     */
    public CLARA(DistanceMetric dm) {
        this(dm, RandomUtil.getRandom());
    }

    /** Creates a CLARA clusterer using the {@link EuclideanDistance}. */
    public CLARA() {
        this(new EuclideanDistance());
    }

    /**
     * Copy constructor.
     *
     * @param toCopy the instance whose configuration is copied
     */
    public CLARA(CLARA toCopy) {
        super(toCopy);
        this.sampleSize = toCopy.sampleSize;
        this.sampleCount = toCopy.sampleCount;
        this.autoSampleSize = toCopy.autoSampleSize;
    }

    /**
     * @return the number of samples drawn per clustering
     */
    public int getSampleCount() {
        return this.sampleCount;
    }

    /**
     * Sets the number of samples drawn per clustering. The value is stored
     * without validation.
     *
     * @param sampleCount the number of samples to draw
     */
    public void setSampleCount(int sampleCount) {
        this.sampleCount = sampleCount;
    }

    /**
     * @return the explicit sample size, or {@code -1} when the sample size is
     *         chosen automatically
     */
    public int getSampleSize() {
        return this.sampleSize;
    }

    /**
     * Sets the number of points drawn for each sample. A non-negative value
     * fixes the sample size; a negative value switches to automatic sizing.
     *
     * @param sampleSize the sample size, or a negative value for automatic sizing
     */
    public void setSampleSize(int sampleSize) {
        if (sampleSize >= 0) {
            this.autoSampleSize = false;
            this.sampleSize = sampleSize;
        } else {
            this.autoSampleSize = true;
            // Mirror the automatic-size constructor so a previously set explicit
            // size does not linger and get reported by getSampleSize().
            this.sampleSize = -1;
        }
    }

    /**
     * Draws {@code sampleCount} random samples, runs PAM on each, and keeps the
     * medoids with the smallest sum of squared distances over the full data set.
     *
     * <p>If the effective sample size is at least {@code data.size()}, sampling
     * is skipped and a single initialising PAM run over the whole data set is
     * performed instead. If {@code sampleCount} is not positive no sample is
     * drawn: {@code medioids} and {@code assignments} are overwritten with zeros
     * and {@link Double#MAX_VALUE} is returned.
     *
     * @param data        the data set to cluster
     * @param doInit      when {@code true} the distance metric is trained if
     *                    needed and the acceleration cache is rebuilt, replacing
     *                    {@code cacheAccel}
     * @param medioids    output: indices into {@code data} of the best medoids;
     *                    its length determines the number of clusters
     * @param assignments output: for each data point, the position in
     *                    {@code medioids} of its closest medoid
     * @param cacheAccel  acceleration cache for {@code data}, used when
     *                    {@code doInit} is {@code false}
     * @param parallel    forwarded unchanged to the PAM implementation
     * @return the sum of squared distances from every point to its closest
     *         medoid for the best sample (or the value returned by PAM when
     *         sampling is skipped)
     */
    @Override
    protected double cluster(DataSet data, boolean doInit, int[] medioids, int[] assignments, List<Double> cacheAccel, boolean parallel) {
        final int k = medioids.length;
        final int n = data.size();
        // Resolve the effective sample size BEFORE the guard below. Testing the
        // raw field would miss automatic mode and let the unique-index sampling
        // loop spin forever when 40 + 2k exceeds the number of points.
        final int sampSize = this.autoSampleSize ? 40 + 2 * k : this.sampleSize;
        if (sampSize >= n) {
            // The sample would cover the whole data set: one round of PAM suffices.
            return super.cluster(data, true, medioids, assignments, cacheAccel, parallel);
        }

        final List<Vec> X = data.getDataVectors();
        if (doInit) {
            TrainableDistanceMetric.trainIfNeeded(this.dm, data);
            cacheAccel = this.dm.getAccelerationCache(X);
        }

        int[] bestMedoids = new int[k];
        int[] bestAssignments = new int[assignments.length];
        double bestMedoidsDist = Double.MAX_VALUE;

        int[] sampleAssignments = new int[sampSize];
        // sampleToData[p] is the index in the full data set of sample point p.
        int[] sampleToData = new int[sampSize];
        List<DataPoint> sample = new ArrayList<DataPoint>(sampSize);
        // Keyed by data-set index so the duplicate check is a hash lookup;
        // the value is the point's position within the sample.
        LinkedHashMap<Integer, Integer> chosen = new LinkedHashMap<Integer, Integer>();

        for (int round = 0; round < this.sampleCount; ++round) {
            chosen.clear();
            sample.clear();
            // Rejection-sample sampSize distinct indices (sampSize < n here, so
            // this terminates).
            while (chosen.size() < sampSize) {
                int indx = this.rand.nextInt(n);
                if (chosen.containsKey(indx)) {
                    continue;
                }
                int pos = chosen.size();
                chosen.put(indx, pos);
                sampleToData[pos] = indx;
                sample.add(data.getDataPoint(indx));
            }

            SimpleDataSet sampleSet = new SimpleDataSet(sample);
            // Build the sample's cache from its own vectors instead of indexing
            // into the full cache: the full cache may be null, and its layout is
            // defined by the metric.
            List<Double> subCache = this.dm.getAccelerationCache(sampleSet.getDataVectors());

            // Run PAM on the sample; medioids then holds sample-local indices.
            SeedSelectionMethods.selectIntialPoints((DataSet) sampleSet, medioids, this.dm, subCache, this.rand, this.getSeedSelection());
            super.cluster(sampleSet, false, medioids, sampleAssignments, subCache, parallel);

            // Translate sample-local medoid indices back to full data set indices.
            for (int j = 0; j < k; ++j) {
                medioids[j] = sampleToData[medioids[j]];
            }

            // Score these medoids against every point in the full data set.
            double sqrdDist = 0.0;
            for (int j = 0; j < n; ++j) {
                double smallestDist = Double.MAX_VALUE;
                int assignment = -1;
                for (int z = 0; z < k; ++z) {
                    double tmp = this.dm.dist(medioids[z], j, X, cacheAccel);
                    if (tmp < smallestDist) {
                        assignment = z;
                        smallestDist = tmp;
                    }
                }
                assignments[j] = assignment;
                sqrdDist += smallestDist * smallestDist;
            }

            if (sqrdDist < bestMedoidsDist) {
                bestMedoidsDist = sqrdDist;
                System.arraycopy(medioids, 0, bestMedoids, 0, k);
                System.arraycopy(assignments, 0, bestAssignments, 0, assignments.length);
            }
        }

        System.arraycopy(bestMedoids, 0, medioids, 0, k);
        System.arraycopy(bestAssignments, 0, assignments, 0, assignments.length);
        return bestMedoidsDist;
    }

    /**
     * @return a copy of this clusterer created through the copy constructor
     */
    @Override
    public CLARA clone() {
        return new CLARA(this);
    }
}

