/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.clustering.ClusterFailureException;
import jsat.clustering.KClusterer;
import jsat.clustering.SeedSelectionMethods;
import jsat.distributions.multivariate.MultivariateDistribution;
import jsat.distributions.multivariate.NormalM;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * Gaussian mixture model fit with the Expectation-Maximization (EM) algorithm.
 * <p>
 * The model is usable both as a {@link KClusterer} (each point is assigned to the
 * component with the highest posterior responsibility) and as a
 * {@link MultivariateDistribution} (the fitted mixture density).
 * <p>
 * Initialisation: {@code K} means are chosen with the configured
 * {@link SeedSelectionMethods.SeedSelection} under Euclidean distance. Every point is
 * hard-assigned to its nearest mean, and the initial covariances and mixing weights are
 * estimated from those hard assignments. EM then alternates E- and M-steps until the
 * log-likelihood changes by less than the tolerance or the iteration limit is reached.
 */
public class EMGaussianMixture implements KClusterer, MultivariateDistribution {
    private static final long serialVersionUID = 2606159815670221662L;

    /** Method used to pick the initial means. */
    private SeedSelectionMethods.SeedSelection seedSelection;
    /** The fitted mixture components. */
    private List<NormalM> gaussians;
    /** Mixing weight of each component (normalised by the number of points after fitting). */
    private double[] a_k;
    /** Convergence threshold on the absolute change in log-likelihood between iterations. */
    private double tolerance = 0.001;
    /** Maximum number of M-steps performed while fitting. */
    protected int MaxIterLimit = Integer.MAX_VALUE;

    /**
     * Creates a new mixture model.
     *
     * @param seedSelection the method used to choose the initial means
     */
    public EMGaussianMixture(SeedSelectionMethods.SeedSelection seedSelection) {
        this.setSeedSelection(seedSelection);
    }

    /** Creates a new mixture model using the {@code KPP} seed selection method. */
    public EMGaussianMixture() {
        this(SeedSelectionMethods.SeedSelection.KPP);
    }

    /**
     * Sets the method used to choose the initial means.
     *
     * @param seedSelection the seed selection method
     */
    public void setSeedSelection(SeedSelectionMethods.SeedSelection seedSelection) {
        this.seedSelection = seedSelection;
    }

    /**
     * Returns the method used to choose the initial means.
     *
     * @return the seed selection method
     */
    public SeedSelectionMethods.SeedSelection getSeedSelection() {
        return this.seedSelection;
    }

    /**
     * Sets the maximum number of EM iterations (M-steps) performed while fitting.
     *
     * @param iterLimit the iteration limit, at least 1
     * @throws IllegalArgumentException if {@code iterLimit < 1}
     */
    public void setIterationLimit(int iterLimit) {
        if (iterLimit < 1) {
            throw new IllegalArgumentException("Iterations must be a positive value, not " + iterLimit);
        }
        this.MaxIterLimit = iterLimit;
    }

    /**
     * Returns the maximum number of EM iterations (M-steps) performed while fitting.
     *
     * @return the iteration limit
     */
    public int getIterationLimit() {
        return this.MaxIterLimit;
    }

    /**
     * Copy constructor. Component distributions and mixing weights are deep-copied.
     *
     * @param gm the mixture to copy
     */
    public EMGaussianMixture(EMGaussianMixture gm) {
        // Without the seed selection a copy (and therefore clone()) could not be re-fit.
        this.seedSelection = gm.seedSelection;
        if (gm.gaussians != null && !gm.gaussians.isEmpty()) {
            this.gaussians = new ArrayList<>(gm.gaussians.size());
            for (NormalM gaussian : gm.gaussians) {
                this.gaussians.add(gaussian.clone());
            }
        }
        if (gm.a_k != null) {
            this.a_k = Arrays.copyOf(gm.a_k, gm.a_k.length);
        }
        this.MaxIterLimit = gm.MaxIterLimit;
        this.tolerance = gm.tolerance;
    }

    /**
     * Builds a mixture directly from components and weights. The first
     * {@code a_k.length} components are cloned.
     *
     * @param gaussians the component distributions
     * @param a_k the mixing weights
     * @param tolerance the convergence tolerance
     */
    private EMGaussianMixture(List<NormalM> gaussians, double[] a_k, double tolerance) {
        this.seedSelection = SeedSelectionMethods.SeedSelection.KPP;
        this.tolerance = tolerance;
        this.gaussians = new ArrayList<>(a_k.length);
        this.a_k = Arrays.copyOf(a_k, a_k.length);
        for (int i = 0; i < a_k.length; i++) {
            this.gaussians.add(gaussians.get(i).clone());
        }
    }

    /**
     * Initialises the mixture from (possibly freshly seeded) means and runs EM.
     *
     * @param dataSet the data to fit
     * @param accelCache distance acceleration cache passed to the Euclidean metric; may be {@code null}
     * @param K the number of components
     * @param means in/out list of means. It is re-seeded when it holds fewer than {@code K}
     *              entries, and is updated in place by EM.
     * @param assignment output array receiving the cluster index of every point
     * @param exactTotal not used by this implementation
     * @param parallel whether to run the E/M steps in parallel
     * @param returnError not used by this implementation
     * @return the negated log-likelihood reported by {@link #clusterCompute}
     */
    protected double cluster(DataSet dataSet, List<Double> accelCache, int K, List<Vec> means, int[] assignment, boolean exactTotal, boolean parallel, boolean returnError) {
        EuclideanDistance dm = new EuclideanDistance();
        if (means.size() < K) {
            means.clear();
            means.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, K, (DistanceMetric) dm, accelCache, RandomUtil.getRandom(), this.seedSelection, parallel));
        }
        // Query info is needed for every mean, whether seeded above or supplied by the caller.
        List<List<Double>> meansQi = new ArrayList<>(means.size());
        for (Vec mean : means) {
            meansQi.add(dm.getQueryInfo(mean));
        }

        int dimension = dataSet.getNumNumericalVars();
        List<Matrix> covariances = new ArrayList<>(K);
        for (int k = 0; k < means.size(); k++) {
            covariances.add(new DenseMatrix(dimension, dimension));
        }

        // Hard-assign each point to its nearest mean and accumulate per-cluster scatter.
        this.a_k = new double[K];
        double total = dataSet.size();
        DenseVector scratch = new DenseVector(dimension);
        List<Vec> X = dataSet.getDataVectors();
        for (int i = 0; i < dataSet.size(); i++) {
            int nearest = 0;
            double nearestDist = dm.dist(i, means.get(0), meansQi.get(0), X, accelCache);
            for (int j = 1; j < K; j++) {
                double d = dm.dist(i, means.get(j), meansQi.get(j), X, accelCache);
                if (d < nearestDist) {
                    nearestDist = d;
                    nearest = j;
                }
            }
            assignment[i] = nearest;
            this.a_k[nearest] += 1.0;
            dataSet.getDataPoint(i).getNumericalValues().copyTo(scratch);
            scratch.mutableSubtract(means.get(nearest));
            Matrix.OuterProductUpdate(covariances.get(nearest), scratch, scratch, 1.0);
        }

        // Scatter -> covariance, and counts -> mixing proportions.
        for (int k = 0; k < means.size(); k++) {
            covariances.get(k).mutableMultiply(1.0 / this.a_k[k]);
            this.a_k[k] /= total;
        }
        return this.clusterCompute(K, dataSet, assignment, means, covariances, parallel);
    }

    /**
     * Runs EM from the given initial parameters, then hard-assigns every point to its most
     * responsible component. Iteration stops when the log-likelihood changes by less than
     * {@code tolerance}, or after {@code MaxIterLimit} M-steps.
     *
     * @param K the number of components
     * @param dataSet the data to fit
     * @param assignment in/out cluster assignment. Entries are replaced by the component with
     *                   strictly higher responsibility.
     * @param means initial means; updated in place
     * @param covs initial covariances; updated in place
     * @param parallel whether to run the E/M steps in parallel
     * @return the negated log-likelihood of the last accepted iteration
     * @throws ArithmeticException if the log-likelihood becomes NaN
     * @throws ClusterFailureException if a worker task fails with a checked exception or the
     *                                 thread is interrupted
     */
    protected double clusterCompute(int K, DataSet dataSet, int[] assignment, List<Vec> means, List<Matrix> covs, boolean parallel) {
        List<DataPoint> dataPoints = dataSet.getDataPoints();
        int N = dataPoints.size();
        double currentLogLike = -Double.MAX_VALUE;

        this.gaussians = new ArrayList<>(K);
        for (int k = 0; k < means.size(); k++) {
            this.gaussians.add(new NormalM(means.get(k), covs.get(k)));
        }

        double[][] p_ik = new double[N][K];
        try {
            int mSteps = 0;
            while (true) {
                double logLike = this.eStep(N, dataPoints, K, p_ik, parallel);
                // abs(NaN) < tolerance is never true, so without this guard EM would never stop.
                if (Double.isNaN(logLike)) {
                    throw new ArithmeticException("EM log-likelihood became NaN; the mixture has numerically degenerated");
                }
                if (Math.abs(currentLogLike - logLike) < this.tolerance) {
                    break;
                }
                currentLogLike = logLike;
                if (mSteps >= this.MaxIterLimit) {
                    break;
                }
                this.mStep(means, N, dataPoints, K, p_ik, covs, parallel);
                mSteps++;
            }
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw failure("EM clustering was interrupted", ex);
        } catch (ExecutionException ex) {
            Throwable cause = ex.getCause();
            // Surface runtime failures (e.g. ArithmeticException) unchanged so callers can handle them.
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            throw failure("EM clustering failed in a worker task", ex);
        }

        // Strict comparison: ties keep the current assignment.
        for (int i = 0; i < N; i++) {
            double[] p_i = p_ik[i];
            for (int k = 0; k < K; k++) {
                if (p_i[k] > p_i[assignment[i]]) {
                    assignment[i] = k;
                }
            }
        }
        return -currentLogLike;
    }

    /**
     * Creates a {@link ClusterFailureException} carrying the given cause.
     *
     * @param message the failure description
     * @param cause the underlying exception
     * @return the exception to throw
     */
    private static ClusterFailureException failure(String message, Throwable cause) {
        ClusterFailureException failure = new ClusterFailureException(message);
        failure.initCause(cause);
        return failure;
    }

    /**
     * M-step: re-estimates means, covariances, mixing weights and the component distributions
     * from the responsibilities {@code p_ik}.
     *
     * @param means component means, overwritten in place
     * @param N number of data points
     * @param dataPoints the data
     * @param K number of components
     * @param p_ik responsibilities computed by the preceding E-step
     * @param covs component covariances, overwritten in place
     * @param parallel whether to process chunks of points in parallel
     * @throws InterruptedException if the parallel work is interrupted
     */
    private void mStep(List<Vec> means, int N, List<DataPoint> dataPoints, int K, double[][] p_ik, List<Matrix> covs, boolean parallel) throws InterruptedException {
        int dimension = means.get(0).length();

        // Responsibility-weighted sums of the points and of the responsibilities themselves.
        for (Vec mean : means) {
            mean.zeroOut();
        }
        Arrays.fill(this.a_k, 0.0);
        ParallelUtils.run(parallel, N, (start, end) -> {
            Vec localMean = new DenseVector(dimension);
            for (int k = 0; k < K; k++) {
                localMean.zeroOut();
                double localWeight = 0.0;
                for (int i = start; i < end; i++) {
                    localWeight += p_ik[i][k];
                    localMean.mutableAdd(p_ik[i][k], dataPoints.get(i).getNumericalValues());
                }
                Vec mean = means.get(k);
                // a_k[k] is only updated while holding the lock on means.get(k).
                synchronized (mean) {
                    mean.mutableAdd(localMean);
                    this.a_k[k] += localWeight;
                }
            }
        });
        for (int k = 0; k < this.a_k.length; k++) {
            means.get(k).mutableDivide(this.a_k[k]);
        }

        // Responsibility-weighted scatter matrices around the new means.
        for (Matrix cov : covs) {
            cov.zeroOut();
        }
        ParallelUtils.run(parallel, N, (start, end) -> {
            DenseVector scratch = new DenseVector(dimension);
            Matrix localCov = covs.get(0).clone();
            for (int k = 0; k < K; k++) {
                Vec mean = means.get(k);
                localCov.zeroOut();
                for (int i = start; i < end; i++) {
                    dataPoints.get(i).getNumericalValues().copyTo(scratch);
                    scratch.mutableSubtract(mean);
                    Matrix.OuterProductUpdate(localCov, scratch, scratch, p_ik[i][k]);
                }
                Matrix cov = covs.get(k);
                synchronized (cov) {
                    cov.mutableAdd(localCov);
                }
            }
        });

        // Normalise covariances by total responsibility, then turn that into a mixing weight.
        for (int k = 0; k < K; k++) {
            covs.get(k).mutableMultiply(1.0 / this.a_k[k]);
            this.a_k[k] /= N;
        }
        for (int k = 0; k < means.size(); k++) {
            this.gaussians.get(k).setMeanCovariance(means.get(k), covs.get(k));
        }
    }

    /**
     * E-step: fills {@code p_ik} with normalised responsibilities
     * {@code a_k[k] * pdf_k(x_i) / sum_j a_k[j] * pdf_j(x_i)}.
     *
     * @param N number of data points
     * @param dataPoints the data
     * @param K number of components
     * @param p_ik output responsibilities, one row per point
     * @param parallel whether to process chunks of points in parallel
     * @return the log-likelihood of the data under the current mixture
     * @throws InterruptedException if the parallel work is interrupted
     * @throws ExecutionException if a parallel task fails
     */
    private double eStep(int N, List<DataPoint> dataPoints, int K, double[][] p_ik, boolean parallel) throws InterruptedException, ExecutionException {
        return ParallelUtils.run(parallel, N, (start, end) -> {
            double logLikeLocal = 0.0;
            for (int i = start; i < end; i++) {
                Vec x_i = dataPoints.get(i).getNumericalValues();
                double[] p_i = p_ik[i];
                double normalizer = 0.0;
                for (int k = 0; k < K; k++) {
                    p_i[k] = this.a_k[k] * this.gaussians.get(k).pdf(x_i);
                    normalizer += p_i[k];
                }
                for (int k = 0; k < K; k++) {
                    p_i[k] /= normalizer;
                }
                logLikeLocal += Math.log(normalizer);
            }
            return logLikeLocal;
        }, (t, u) -> t + u);
    }

    /**
     * Returns the log of the mixture density at {@code x}.
     *
     * @param x the point to evaluate
     * @return {@code log(pdf(x))}, or {@code -Double.MAX_VALUE} when the density is exactly zero
     */
    @Override
    public double logPdf(Vec x) {
        double pdf = this.pdf(x);
        if (pdf == 0.0) {
            return -Double.MAX_VALUE;
        }
        return Math.log(pdf);
    }

    /**
     * Returns the mixture density at {@code x}: the weighted sum of the component densities.
     *
     * @param x the point to evaluate
     * @return the mixture density
     */
    @Override
    public double pdf(Vec x) {
        double density = 0.0;
        for (int k = 0; k < this.a_k.length; k++) {
            density += this.a_k[k] * this.gaussians.get(k).pdf(x);
        }
        return density;
    }

    /**
     * Fits the mixture to a list of vectors by wrapping them in a {@link SimpleDataSet}.
     *
     * @param dataSet the vectors to fit
     * @param parallel whether to fit in parallel
     * @return the result of {@link #setUsingData(DataSet, boolean)}
     */
    @Override
    public <V extends Vec> boolean setUsingData(List<V> dataSet, boolean parallel) {
        List<DataPoint> dataPoints = new ArrayList<>(dataSet.size());
        for (Vec x : dataSet) {
            dataPoints.add(new DataPoint(x, new int[0], new CategoricalData[0]));
        }
        return this.setUsingData(new SimpleDataSet(dataPoints), parallel);
    }

    /**
     * Fits the mixture by clustering the data set.
     * <p>
     * NOTE(review): this calls {@code cluster(dataSet, parallel)}. If that overload routes to
     * {@link #cluster(DataSet, boolean, int[])} it will throw
     * {@link UnsupportedOperationException}, because the lowK/highK overloads are unsupported.
     * Confirm against the {@code Clusterer} interface.
     *
     * @param dataSet the data to fit
     * @param parallel whether to fit in parallel
     * @return {@code true} on success, {@code false} if an {@link ArithmeticException} occurred
     */
    @Override
    public boolean setUsingData(DataSet dataSet, boolean parallel) {
        try {
            this.cluster(dataSet, parallel);
            return true;
        } catch (ArithmeticException ex) {
            return false;
        }
    }

    /** Returns a deep copy of this mixture (see the copy constructor). */
    @Override
    public EMGaussianMixture clone() {
        return new EMGaussianMixture(this);
    }

    /**
     * Draws {@code count} samples from the mixture. Sorted uniform draws are routed to the
     * component whose cumulative-weight interval contains them, and each component then
     * samples its share.
     *
     * @param count the number of samples
     * @param rand the source of randomness
     * @return exactly {@code count} samples when at least one component exists
     */
    @Override
    public List<Vec> sample(int count, Random rand) {
        List<Vec> samples = new ArrayList<>(count);
        double[] draws = new double[count];
        for (int i = 0; i < count; i++) {
            draws[i] = rand.nextDouble();
        }
        Arrays.sort(draws);

        int pos = 0;
        double cumulativeWeight = 0.0;
        for (int k = 0; k < this.a_k.length; k++) {
            cumulativeWeight += this.a_k[k];
            int first = pos;
            if (k == this.a_k.length - 1) {
                // The last component takes whatever remains, so rounding in the weight sum cannot drop draws.
                pos = count;
            } else {
                while (pos < count && draws[pos] < cumulativeWeight) {
                    pos++;
                }
            }
            int share = pos - first;
            if (share > 0) {
                samples.addAll(this.gaussians.get(k).sample(share, rand));
            }
        }
        return samples;
    }

    /**
     * Clusters without a fixed K. This delegates to the lowK/highK overload, which this class
     * does not support, so it throws {@link UnsupportedOperationException}.
     */
    @Override
    public int[] cluster(DataSet dataSet, int[] designations) {
        return this.cluster(dataSet, 2, (int) Math.sqrt(dataSet.size() / 2), designations);
    }

    /**
     * Clusters without a fixed K. This delegates to the lowK/highK overload, which this class
     * does not support, so it throws {@link UnsupportedOperationException}.
     */
    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        return this.cluster(dataSet, 2, (int) Math.sqrt(dataSet.size() / 2), parallel, designations);
    }

    /**
     * Fits a mixture with the given number of components and returns hard cluster assignments.
     *
     * @param dataSet the data to cluster
     * @param clusters the number of components
     * @param parallel whether to fit in parallel
     * @param designations output array; allocated when {@code null}
     * @return the cluster index of every point
     * @throws ClusterFailureException if there are fewer points than clusters
     */
    @Override
    public int[] cluster(DataSet dataSet, int clusters, boolean parallel, int[] designations) {
        if (designations == null) {
            designations = new int[dataSet.size()];
        }
        if (dataSet.size() < clusters) {
            throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
        }
        List<Vec> means = new ArrayList<>(clusters);
        this.cluster(dataSet, null, clusters, means, designations, false, parallel, false);
        return designations;
    }

    /** Not supported: always throws {@link UnsupportedOperationException}. */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, boolean parallel, int[] designations) {
        throw new UnsupportedOperationException("EMGaussianMixture does not supported determining the number of clusters");
    }

    /** Not supported: always throws {@link UnsupportedOperationException}. */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, int[] designations) {
        throw new UnsupportedOperationException("EMGaussianMixture does not supported determining the number of clusters");
    }
}

