/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.Arrays;
import java.util.Random;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.clustering.KClusterer;
import jsat.clustering.KClustererBase;
import jsat.clustering.evaluation.IntraClusterSumEvaluation;
import jsat.clustering.evaluation.intra.SumOfSqrdPairwiseDistances;
import jsat.clustering.kmeans.HamerlyKMeans;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.SingularValueDecomposition;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.random.RandomUtil;

/**
 * Chooses the number of clusters for a wrapped {@link KClusterer} using a Gap
 * statistic style comparison.
 * <p>
 * For every candidate {@code k} the log of the intra-cluster sum of squared
 * pairwise distances ({@code logW}) is computed on the observed data and
 * compared against its mean ({@code ElogW}) over {@code B} reference data sets
 * that are sampled uniformly from a bounding box of the data. The difference
 * {@code ElogW - logW} is the gap for that {@code k}.
 * <p>
 * All per-{@code k} arrays exposed by this class store the value for
 * {@code k} clusters at index {@code k - 1}.
 */
public class GapStatistic
extends KClustererBase
implements Parameterized {
    private static final long serialVersionUID = 8893929177942856618L;
    /** The clustering algorithm that performs the actual clustering for a fixed k. */
    @Parameter.ParameterHolder
    private KClusterer base;
    /** Number of reference data sets sampled per run; always positive. */
    private int B;
    /** Metric used when evaluating the intra-cluster sum of squared pairwise distances. */
    private DistanceMetric dm;
    /** Whether reference data is sampled in the space of the right singular vectors. */
    private boolean PCSampling;
    /** Mean of log(W_k) over the reference data sets, index k - 1. */
    private double[] ElogW;
    /** log(W_k) of the observed data, index k - 1. */
    private double[] logW;
    /** ElogW - logW, index k - 1. */
    private double[] gap;
    /** Standard deviation of the reference log(W_k), scaled by sqrt(1 + 1/B), index k - 1. */
    private double[] s_k;

    /**
     * Creates a Gap statistic clusterer that uses {@link HamerlyKMeans} as the
     * base clustering algorithm.
     */
    public GapStatistic() {
        this(new HamerlyKMeans());
    }

    /**
     * Creates a Gap statistic clusterer with PC sampling disabled.
     *
     * @param base the clustering algorithm to run for each candidate k
     */
    public GapStatistic(KClusterer base) {
        this(base, false);
    }

    /**
     * Creates a Gap statistic clusterer using 10 reference samples and the
     * {@link EuclideanDistance}.
     *
     * @param base the clustering algorithm to run for each candidate k
     * @param PCSampling {@code true} to sample reference data in the space of
     * the right singular vectors of the data matrix
     */
    public GapStatistic(KClusterer base, boolean PCSampling) {
        this(base, PCSampling, 10, new EuclideanDistance());
    }

    /**
     * Creates a fully specified Gap statistic clusterer.
     *
     * @param base the clustering algorithm to run for each candidate k
     * @param PCSampling {@code true} to sample reference data in the space of
     * the right singular vectors of the data matrix
     * @param B the number of reference data sets to sample
     * @param dm the distance metric used to evaluate cluster dispersion
     * @throws IllegalArgumentException if {@code B} is not positive
     */
    public GapStatistic(KClusterer base, boolean PCSampling, int B, DistanceMetric dm) {
        this.base = base;
        this.setSamples(B);
        this.setDistanceMetric(dm);
        this.setPCSampling(PCSampling);
    }

    /**
     * Copy constructor. The base clusterer and distance metric are cloned and
     * any statistic arrays already computed by {@code toCopy} are duplicated.
     *
     * @param toCopy the object to copy
     */
    public GapStatistic(GapStatistic toCopy) {
        this.base = toCopy.base.clone();
        this.B = toCopy.B;
        this.dm = toCopy.dm.clone();
        this.PCSampling = toCopy.PCSampling;
        // The statistic arrays only exist after a ranged cluster call, so each may be null.
        if (toCopy.ElogW != null) {
            this.ElogW = Arrays.copyOf(toCopy.ElogW, toCopy.ElogW.length);
        }
        if (toCopy.logW != null) {
            this.logW = Arrays.copyOf(toCopy.logW, toCopy.logW.length);
        }
        if (toCopy.gap != null) {
            this.gap = Arrays.copyOf(toCopy.gap, toCopy.gap.length);
        }
        if (toCopy.s_k != null) {
            this.s_k = Arrays.copyOf(toCopy.s_k, toCopy.s_k.length);
        }
    }

    /**
     * Sets the distance metric used when evaluating cluster dispersion.
     *
     * @param dm the distance metric to use
     */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /**
     * @return the distance metric used when evaluating cluster dispersion
     */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * Controls how the reference data sets are sampled. When {@code true} the
     * uniform sampling box is computed on the data multiplied by the V matrix
     * of its singular value decomposition and samples are mapped back with
     * V<sup>T</sup>; when {@code false} the box is the per-column min/max of
     * the original data.
     *
     * @param PCSampling {@code true} to sample in the singular vector space
     */
    public void setPCSampling(boolean PCSampling) {
        this.PCSampling = PCSampling;
    }

    /**
     * @return {@code true} if reference data is sampled in the singular vector space
     */
    public boolean isPCSampling() {
        return this.PCSampling;
    }

    /**
     * Sets the number of reference data sets that are sampled and clustered.
     *
     * @param B the number of reference data sets
     * @throws IllegalArgumentException if {@code B} is not positive
     */
    public void setSamples(int B) {
        if (B <= 0) {
            throw new IllegalArgumentException("sample size must be positive, not " + B);
        }
        this.B = B;
    }

    /**
     * @return the number of reference data sets that are sampled
     */
    public int getSamples() {
        return this.B;
    }

    /**
     * Returns the gap values from the last ranged clustering run. The internal
     * array is returned directly (not a copy).
     *
     * @return the gap for k clusters at index k - 1, or {@code null} if no
     * ranged clustering has been run yet
     */
    public double[] getGap() {
        return this.gap;
    }

    /**
     * Returns log(W_k) of the observed data from the last ranged clustering
     * run. The internal array is returned directly (not a copy).
     *
     * @return log(W_k) for k clusters at index k - 1, or {@code null} if no
     * ranged clustering has been run yet
     */
    public double[] getLogW() {
        return this.logW;
    }

    /**
     * Returns the mean of log(W_k) over the reference data sets from the last
     * ranged clustering run. The internal array is returned directly (not a
     * copy).
     *
     * @return the mean reference log(W_k) for k clusters at index k - 1, or
     * {@code null} if no ranged clustering has been run yet
     */
    public double[] getElogW() {
        return this.ElogW;
    }

    /**
     * Returns the standard deviation of the reference log(W_k) values scaled
     * by sqrt(1 + 1/B) from the last ranged clustering run. The internal array
     * is returned directly (not a copy).
     *
     * @return the scaled standard deviation for k clusters at index k - 1, or
     * {@code null} if no ranged clustering has been run yet
     */
    public double[] getElogWkStndDev() {
        return this.s_k;
    }

    /**
     * Clusters the data, searching for k between 1 and the square root of the
     * data set size clamped to [10, 100] (truncated to an int).
     *
     * @param dataSet the data to cluster
     * @param parallel forwarded to the base clusterer
     * @param designations optional array to store the cluster assignments in
     * @return the cluster assignment of every data point
     */
    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        return this.cluster(dataSet, 1, (int)Math.min(Math.max(Math.sqrt(dataSet.size()), 10.0), 100.0), parallel, designations);
    }

    /**
     * Clusters the data into a fixed number of clusters by delegating directly
     * to the base clusterer; no gap statistics are computed.
     *
     * @param dataSet the data to cluster
     * @param clusters the number of clusters to produce
     * @param parallel forwarded to the base clusterer
     * @param designations optional array to store the cluster assignments in
     * @return the cluster assignment of every data point
     */
    @Override
    public int[] cluster(DataSet dataSet, int clusters, boolean parallel, int[] designations) {
        return this.base.cluster(dataSet, clusters, parallel, designations);
    }

    /**
     * Computes the gap statistics for k = 1 .. {@code highK - 1}, selects a
     * number of clusters and returns the base clusterer's result for it.
     * <p>
     * Selection: scanning k upward, the first k with
     * {@code lowK <= k <= highK} for which {@code gap(k-1) >= gap(k) - s_k}
     * and {@code gap(k-1) > 0} selects {@code k - 1} clusters. If no k
     * satisfies this, the in-range k with the largest gap is used; if no k is
     * in range at all, a single cluster is used.
     * <p>
     * Note that the statistic arrays have length {@code highK - 1}, so
     * {@code k = highK} itself is never evaluated, and {@code highK} must be
     * at least 2 for the arrays to be allocated and filled.
     *
     * @param dataSet the data to cluster
     * @param lowK the smallest k considered by the selection rule
     * @param highK the upper limit of the search; must be at least 2
     * @param parallel forwarded to the base clusterer
     * @param designations optional array to store the cluster assignments in;
     * a new one is allocated if it is {@code null} or too short
     * @return the cluster assignment of every data point
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, boolean parallel, int[] designations) {
        final int D = dataSet.getNumNumericalVars();
        final int N = dataSet.size();
        if (designations == null || designations.length < N) {
            designations = new int[N];
        }
        this.logW = new double[highK - 1];
        this.ElogW = new double[highK - 1];
        this.gap = new double[highK - 1];
        this.s_k = new double[highK - 1];
        IntraClusterSumEvaluation ssd = new IntraClusterSumEvaluation(new SumOfSqrdPairwiseDistances(this.dm));

        // Step 1: dispersion of the observed data. k = 1 is the trivial
        // "everything in cluster 0" assignment, so no clustering call is needed.
        Arrays.fill(designations, 0);
        this.logW[0] = Math.log(ssd.evaluate(designations, dataSet));
        for (int k = 2; k < highK; ++k) {
            designations = this.base.cluster(dataSet, k, parallel, designations);
            this.logW[k - 1] = Math.log(ssd.evaluate(designations, dataSet));
        }

        // Step 2: running statistics of the reference dispersion, one per k, so
        // the B reference results never need to be stored individually.
        OnLineStatistics[] expected = new OnLineStatistics[highK - 1];
        for (int idx = 0; idx < expected.length; ++idx) {
            expected[idx] = new OnLineStatistics();
        }

        // Reference data set whose vectors are overwritten in place for every sample b.
        SimpleDataSet Xp = new SimpleDataSet(D, new CategoricalData[0]);
        for (int n = 0; n < N; ++n) {
            Xp.add(new DataPoint(new DenseVector(D)));
        }
        Random rand = RandomUtil.getRandom();

        // Per-dimension bounds of the box the reference data is sampled from.
        double[] min = new double[D];
        double[] max = new double[D];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        final Matrix V_T;
        if (this.PCSampling) {
            // Bounds are taken on X' = X V; samples are later mapped back with V^T.
            SingularValueDecomposition svd = new SingularValueDecomposition(dataSet.getDataMatrix());
            Matrix projected = dataSet.getDataMatrixView().multiply(svd.getV());
            for (int r = 0; r < projected.rows(); ++r) {
                for (int c = 0; c < projected.cols(); ++c) {
                    min[c] = Math.min(projected.get(r, c), min[c]);
                    max[c] = Math.max(projected.get(r, c), max[c]);
                }
            }
            V_T = svd.getV().transpose();
        } else {
            V_T = null;
            OnLineStatistics[] columnStats = dataSet.getOnlineColumnStats(false);
            for (int d = 0; d < D; ++d) {
                min[d] = columnStats[d].getMin();
                max[d] = columnStats[d].getMax();
            }
        }

        for (int b = 0; b < this.B; ++b) {
            // Draw every reference point uniformly from the bounding box.
            for (int n = 0; n < N; ++n) {
                Vec xp = Xp.getDataPoint(n).getNumericalValues();
                for (int d = 0; d < D; ++d) {
                    xp.set(d, (max[d] - min[d]) * rand.nextDouble() + min[d]);
                }
            }
            if (this.PCSampling) {
                // Map each sample back via V^T, using tmp as scratch space that
                // is cleared before every multiplication.
                DenseVector tmp = new DenseVector(D);
                for (int n = 0; n < N; ++n) {
                    Vec xp = Xp.getDataPoint(n).getNumericalValues();
                    tmp.zeroOut();
                    xp.multiply(V_T, tmp);
                    tmp.copyTo(xp);
                }
            }
            // Same evaluation as for the observed data, but on the reference sample.
            Arrays.fill(designations, 0);
            expected[0].add(Math.log(ssd.evaluate(designations, Xp)));
            for (int k = 2; k < highK; ++k) {
                designations = this.base.cluster((DataSet)Xp, k, parallel, designations);
                expected[k - 1].add(Math.log(ssd.evaluate(designations, Xp)));
            }
        }

        // Step 3: compute the gaps and pick k.
        int kFirst = -1;
        // Index of the largest gap among in-range k, used as a fallback when the
        // primary rule never fires; -1 while no in-range k has been seen.
        int biggestGap = -1;
        for (int i = 0; i < this.gap.length; ++i) {
            this.ElogW[i] = expected[i].getMean();
            this.gap[i] = this.ElogW[i] - this.logW[i];
            this.s_k[i] = expected[i].getStandardDeviation() * Math.sqrt(1.0 + 1.0 / (double)this.B);
            int k = i + 1;
            boolean inRange = lowK <= k && k <= highK;
            // NOTE(review): the range test is applied to k while k - 1 is what
            // gets selected, so the result can be lowK - 1 - confirm intended.
            if (i > 0 && inRange && kFirst == -1 && this.gap[i - 1] >= this.gap[i] - this.s_k[i] && this.gap[i - 1] > 0.0) {
                kFirst = k - 1;
            }
            // Compare gap values with each other; comparing against the index
            // itself would pick the wrong k whenever the gaps are not on the
            // same scale as the indices (e.g. all negative or all below 1).
            if (inRange && (biggestGap == -1 || this.gap[i] > this.gap[biggestGap])) {
                biggestGap = i;
            }
        }
        if (kFirst == -1) {
            // No in-range k at all leaves biggestGap at -1, which maps to k = 1.
            kFirst = Math.max(biggestGap, 0) + 1;
        }
        if (kFirst == 1) {
            // A single cluster needs no call to the base clusterer.
            Arrays.fill(designations, 0);
            return designations;
        }
        return this.base.cluster(dataSet, kFirst, parallel, designations);
    }

    /**
     * @return a copy of this object created via the copy constructor
     */
    @Override
    public GapStatistic clone() {
        return new GapStatistic(this);
    }
}

