/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.Clusterer;
import jsat.clustering.kmeans.HamerlyKMeans;
import jsat.distributions.multivariate.MultivariateDistribution;
import jsat.distributions.multivariate.NormalM;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.MatrixStatistics;
import jsat.linear.Vec;
import jsat.math.MathTricks;
import jsat.math.SpecialMath;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Gaussian mixture model fit with variational-Bayes style updates (Dirichlet weight prior
 * {@code alpha_0}, mean-precision prior {@code beta_0}, Wishart-style scale updates).
 *
 * <p>{@link #cluster(DataSet, boolean, int[])} starts from up to {@code min(max_k, N/2)}
 * components initialized by {@link HamerlyKMeans}. It alternates parameter updates with
 * responsibility updates and prunes any component whose expected log mixing weight drops below
 * {@code log(prune_tol)}. The surviving components are kept and the object can then be used as
 * a {@link MultivariateDistribution}.
 */
public class VBGMM implements Clusterer, MultivariateDistribution {
    /** Prior concentration for the mixing weights. */
    protected double alpha_0 = 1.0E-5;
    /** Prior strength for the component means. */
    protected double beta_0 = 1.0;
    /** Components whose expected log mixing weight is below {@code log(prune_tol)} are dropped. */
    private double prune_tol = 1.0E-5;
    /** Fitted components; only the components that survived pruning are kept after clustering. */
    protected NormalM[] normals;
    /** Log mixing weight of each entry in {@link #normals}. */
    protected double[] log_pi;
    /** Upper bound on the number of initial components. */
    protected int max_k = 200;
    private int maxIterations = 2000;
    /** Covariance structure (full or diagonal) used for every component. */
    protected COV_FIT_TYPE cov_type = COV_FIT_TYPE.FULL;

    /** Creates a model using full covariance matrices. */
    public VBGMM() {
        this(COV_FIT_TYPE.FULL);
    }

    /**
     * Creates a model with the given covariance structure.
     *
     * @param cov_type full or diagonal covariance fitting
     */
    public VBGMM(COV_FIT_TYPE cov_type) {
        this.cov_type = cov_type;
    }

    /**
     * Copy constructor; deep-copies any fitted components.
     *
     * @param toCopy the model to copy
     */
    public VBGMM(VBGMM toCopy) {
        this.max_k = toCopy.max_k;
        this.maxIterations = toCopy.maxIterations;
        this.prune_tol = toCopy.prune_tol;
        this.beta_0 = toCopy.beta_0;
        this.alpha_0 = toCopy.alpha_0;
        // Without this a copy of a DIAG model would silently fall back to the FULL default.
        this.cov_type = toCopy.cov_type;
        if (toCopy.normals != null) {
            this.normals = Arrays.copyOf(toCopy.normals, toCopy.normals.length);
            for (int i = 0; i < this.normals.length; ++i) {
                this.normals[i] = this.normals[i].clone();
            }
            this.log_pi = Arrays.copyOf(toCopy.log_pi, toCopy.log_pi.length);
        }
    }

    /**
     * Fits the mixture to {@code dataSet} and returns a hard cluster label for each point.
     * Labels index the surviving components, i.e. {@code 0 .. normals.length-1}.
     *
     * @param dataSet the data to cluster; must contain at least two points
     * @param parallel whether to use parallel execution
     * @param designations optional array passed to k-means for the initial labels
     * @return the hard assignment of every point
     * @throws IllegalArgumentException if the data set is too small to seed any component
     */
    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        int N = dataSet.size();
        int k_max = Math.min(this.max_k, N / 2);
        if (k_max < 1) {
            throw new IllegalArgumentException("VBGMM needs at least one initial component, but got k_max="
                    + k_max + " from " + N + " data points and max_k=" + this.max_k);
        }
        int d = dataSet.getNumNumericalVars();
        List<Vec> X = dataSet.getDataVectors();
        this.normals = new NormalM[k_max];
        boolean[] active = new boolean[k_max];
        Arrays.fill(active, true);
        // r[k][n]: responsibility of component k for point n. Between the E-step and the
        // normalization step it temporarily holds the unnormalized log-responsibility.
        double[][] r = new double[k_max][N];
        Matrix[] S_k = new Matrix[k_max];
        double[] beta = new double[k_max];
        double[] alpha = new double[k_max];
        double log_prune_tol = Math.log(this.prune_tol);

        // Prior mean = data mean; prior scale = (co)variance of the whole data set, computed by
        // fitting with unit weight on every point (r[0] is borrowed as that weight vector).
        DenseVector m_0 = new DenseVector(d);
        MatrixStatistics.meanVector((Vec) m_0, dataSet);
        Arrays.fill(r[0], 1.0);
        Matrix W_inv_0 = this.cov_type.allocate(d);
        this.cov_type.fit(X, W_inv_0, r[0], m_0, N);
        Arrays.fill(r[0], 0.0);

        Vec[] m_k = new Vec[k_max];
        Matrix[] W_inv_k = new Matrix[k_max];
        for (int k = 0; k < k_max; ++k) {
            m_k[k] = new DenseVector(d);
            W_inv_k[k] = this.cov_type.allocate(d);
            S_k[k] = this.cov_type.allocate(d);
        }
        double nu_0 = d;
        double[] nu_k = new double[k_max];
        this.log_pi = new double[k_max];
        double[] log_precision = new double[k_max];

        // Hard k-means initialization: each point starts fully owned by its k-means cluster.
        HamerlyKMeans kMeans = new HamerlyKMeans();
        designations = kMeans.cluster(dataSet, k_max, parallel, designations);
        for (int i = 0; i < N; ++i) {
            r[designations[i]][i] = 1.0;
            this.log_pi[designations[i]] += 1.0;
        }
        for (int k = 0; k < k_max; ++k) {
            kMeans.getMeans().get(k).copyTo(m_k[k]);
            if (this.log_pi[k] == 0.0) {
                active[k] = false; // k-means left this cluster empty
            }
            this.log_pi[k] = Math.log(this.log_pi[k]) - Math.log(N);
        }

        double prevLog = Double.POSITIVE_INFINITY;
        for (int iteration = 0; iteration < this.maxIterations; ++iteration) {
            // M-step: weighted statistics and posterior hyper-parameters for each live component.
            ParallelUtils.run(parallel, k_max, k -> {
                if (!active[k]) {
                    return;
                }
                double Nk = 0.0;
                DenseVector xk = new DenseVector(d);
                for (int n = 0; n < N; ++n) {
                    double r_nk = r[k][n];
                    Nk += r_nk;
                    xk.mutableAdd(r_nk, X.get(n));
                }
                xk.mutableDivide(Nk + 1.0E-6);
                this.cov_type.fit(X, S_k[k], r[k], xk, Nk);
                alpha[k] = this.alpha_0 + Nk;
                beta[k] = this.beta_0 + Nk;
                m_k[k].zeroOut();
                m_k[k].mutableAdd(this.beta_0, m_0);
                m_k[k].mutableAdd(Nk, xk);
                m_k[k].mutableDivide(beta[k] + 1.0E-6);
                nu_k[k] = nu_0 + Nk;
                this.cov_type.updateWishart(W_inv_0, W_inv_k[k], S_k[k], xk, Nk, m_0, this.beta_0, beta[k], nu_k[k]);
            });

            // Note: includes the last alpha of components pruned in earlier iterations.
            double alpha_sum = DenseVector.toDenseVec(alpha).sum();
            ParallelUtils.run(parallel, k_max, k -> {
                if (!active[k]) {
                    return;
                }
                this.normals[k] = this.cov_type.asNormal(m_k[k], W_inv_k[k]);
                this.log_pi[k] = SpecialMath.digamma(alpha[k]) - SpecialMath.digamma(alpha_sum);
                if (this.log_pi[k] < log_prune_tol) {
                    active[k] = false;
                }
                double logPrec = d * Math.log(2.0);
                for (int i = 0; i < d; ++i) {
                    logPrec += SpecialMath.digamma((nu_k[k] - i) / 2.0);
                }
                log_precision[k] = logPrec / 2.0;
            });

            // E-step: unnormalized log-responsibilities, summed as the convergence statistic.
            double log_prob_sum = ParallelUtils.run(parallel, k_max, k -> {
                if (!active[k]) {
                    return 0.0;
                }
                double contrib = 0.0;
                for (int n = 0; n < N; ++n) {
                    double log_rnk = this.normals[k].logPdf(X.get(n)) - d / (2.0 * beta[k])
                            + this.log_pi[k] + log_precision[k];
                    r[k][n] = log_rnk;
                    contrib += log_rnk;
                }
                return contrib;
            }, (a, b) -> a + b);

            if (Math.abs((prevLog - log_prob_sum) / prevLog) < 1.0E-5) {
                // r still holds log-responsibilities here; assignHardLabels handles that.
                break;
            }
            prevLog = log_prob_sum;

            // Normalize each point's responsibilities. Shift by the per-point max before exp()
            // so very negative log-densities don't all underflow to 0 and produce 0/0 = NaN.
            ParallelUtils.run(parallel, N, n -> {
                double maxLog = Double.NEGATIVE_INFINITY;
                for (int k = 0; k < k_max; ++k) {
                    if (active[k] && r[k][n] > maxLog) {
                        maxLog = r[k][n];
                    }
                }
                double shift = Double.isInfinite(maxLog) ? 0.0 : maxLog;
                double sum = 0.0;
                for (int k = 0; k < k_max; ++k) {
                    if (!active[k]) {
                        continue;
                    }
                    double p = Math.exp(r[k][n] - shift);
                    r[k][n] = p;
                    sum += p;
                }
                for (int k = 0; k < k_max; ++k) {
                    if (active[k]) {
                        r[k][n] /= sum;
                    }
                }
            });
        }

        compactActiveComponents(active);
        assignHardLabels(r, active, N, designations);
        return designations;
    }

    /**
     * Removes pruned components from {@link #normals} and {@link #log_pi}, keeping the
     * survivors in their original relative order.
     *
     * @param active which of the original components survived
     */
    private void compactActiveComponents(boolean[] active) {
        int kept = 0;
        for (int k = 0; k < active.length; ++k) {
            if (!active[k]) {
                continue;
            }
            this.normals[kept] = this.normals[k];
            this.log_pi[kept] = this.log_pi[k];
            ++kept;
        }
        this.normals = Arrays.copyOf(this.normals, kept);
        this.log_pi = Arrays.copyOf(this.log_pi, kept);
    }

    /**
     * For each point, writes the compacted index of the surviving component with the largest
     * {@code r} value. Starting from -infinity makes this correct whether {@code r} holds
     * normalized probabilities or (after an early convergence break) log-responsibilities,
     * which are frequently all negative. Ties go to the lowest index.
     *
     * @param r responsibilities or log-responsibilities, indexed {@code [component][point]}
     * @param active which of the original components survived
     * @param N number of points
     * @param designations output array of length at least {@code N}
     */
    private static void assignHardLabels(double[][] r, boolean[] active, int N, int[] designations) {
        for (int n = 0; n < N; ++n) {
            int best = 0;
            double bestValue = Double.NEGATIVE_INFINITY;
            int compactIndex = 0;
            for (int k = 0; k < active.length; ++k) {
                if (!active[k]) {
                    continue;
                }
                if (r[k][n] > bestValue) {
                    bestValue = r[k][n];
                    best = compactIndex;
                }
                ++compactIndex;
            }
            designations[n] = best;
        }
    }

    /**
     * Sets the prior concentration for the mixing weights.
     *
     * @param alpha_0 the new prior value
     */
    public void setAlphaPrior(double alpha_0) {
        this.alpha_0 = alpha_0;
    }

    /** @return the prior concentration for the mixing weights */
    public double getAlphaPrior() {
        return this.alpha_0;
    }

    /**
     * Sets the prior strength for the component means.
     *
     * @param beta_0 the new prior value
     */
    public void setBetaPrior(double beta_0) {
        this.beta_0 = beta_0;
    }

    /** @return the prior strength for the component means */
    public double getBetaPrior() {
        return this.beta_0;
    }

    /**
     * Sets the maximum number of update iterations performed by clustering.
     *
     * @param maxIterations the iteration cap
     */
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /** @return the maximum number of update iterations */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    @Override
    public VBGMM clone() {
        return new VBGMM(this);
    }

    /**
     * Log-density of the mixture at {@code x}, computed in the log domain (log-sum-exp) so that
     * points far from every component don't underflow to {@code log(0)}.
     *
     * @param x the query point
     * @return the log-density, or {@code -Double.MAX_VALUE} when every component has zero density
     */
    @Override
    public double logPdf(Vec x) {
        double[] logTerms = new double[this.normals.length];
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < logTerms.length; ++i) {
            logTerms[i] = this.log_pi[i] + this.normals[i].logPdf(x);
            max = Math.max(max, logTerms[i]);
        }
        if (max == Double.NEGATIVE_INFINITY) {
            return -Double.MAX_VALUE;
        }
        if (max == Double.POSITIVE_INFINITY) {
            return max;
        }
        double sum = 0.0;
        for (double t : logTerms) {
            sum += Math.exp(t - max);
        }
        return max + Math.log(sum);
    }

    /**
     * Density of the mixture at {@code x}: the sum of {@code exp(log_pi[i]) * N_i(x)}.
     *
     * @param x the query point
     * @return the mixture density
     */
    @Override
    public double pdf(Vec x) {
        double pdf = 0.0;
        for (int i = 0; i < this.normals.length; ++i) {
            pdf += Math.exp(this.log_pi[i] + this.normals[i].logPdf(x));
        }
        return pdf;
    }

    /**
     * Soft assignment of {@code x} to every surviving component.
     *
     * @param x the query point
     * @return per-component probabilities from a softmax over the weighted log-densities
     */
    public double[] mixtureAssignments(Vec x) {
        double[] assignments = new double[this.normals.length];
        for (int i = 0; i < this.normals.length; ++i) {
            assignments[i] = this.log_pi[i] + this.normals[i].logPdf(x);
        }
        MathTricks.softmax(assignments, false);
        return assignments;
    }

    /**
     * Fits this mixture to the given vectors by clustering them.
     *
     * @param dataSet the vectors to fit
     * @param parallel whether to use parallel execution
     * @return always {@code true}
     */
    @Override
    public <V extends Vec> boolean setUsingData(List<V> dataSet, boolean parallel) {
        SimpleDataSet sds = new SimpleDataSet(dataSet.stream().map(v -> new DataPoint((Vec) v)).collect(Collectors.toList()));
        this.cluster((DataSet) sds, parallel);
        return true;
    }

    /**
     * Draws {@code count} samples. Each sample first picks a component in proportion to
     * {@code exp(log_pi)} (normalized, since those weights need not sum to one) and then
     * samples from that component. The samples are grouped by component.
     *
     * @param count number of samples to draw
     * @param rand source of randomness
     * @return exactly {@code count} samples (empty if there are no components)
     */
    @Override
    public List<Vec> sample(int count, Random rand) {
        ArrayList<Vec> samples = new ArrayList<>(count);
        int K = this.normals.length;
        if (count <= 0 || K == 0) {
            return samples;
        }
        double[] priorTargets = new double[count];
        for (int i = 0; i < count; ++i) {
            priorTargets[i] = rand.nextDouble();
        }
        Arrays.sort(priorTargets);

        double total = 0.0;
        for (double lp : this.log_pi) {
            total += Math.exp(lp);
        }
        int pos = 0;
        double cumulative = 0.0;
        for (int k = 0; k < K; ++k) {
            cumulative += Math.exp(this.log_pi[k]) / total;
            int start = pos;
            if (k == K - 1) {
                pos = count; // the last component absorbs any floating-point leftovers
            } else {
                while (pos < count && priorTargets[pos] < cumulative) {
                    ++pos;
                }
            }
            if (pos > start) {
                samples.addAll(this.normals[k].sample(pos - start, rand));
            }
        }
        return samples;
    }

    /** Covariance structure used for every mixture component. */
    public static enum COV_FIT_TYPE {
        /** Diagonal covariance; the diagonal is stored as row 0 of a {@code 1 x d} matrix. */
        DIAG {
            @Override
            public void fit(List<Vec> X, Matrix S_k, double[] contrib, Vec xk, double Nk) {
                int N = contrib.length;
                int d = xk.length();
                S_k.zeroOut();
                Vec diag = S_k.getRowView(0);
                for (int n = 0; n < N; ++n) {
                    double r_nk = contrib[n];
                    Vec x_n = X.get(n);
                    for (int j = 0; j < d; ++j) {
                        diag.increment(j, r_nk * Math.pow(xk.get(j) - x_n.get(j), 2.0));
                    }
                }
                diag.mutableDivide(Nk + 1.0E-6);
            }

            @Override
            public void updateWishart(Matrix W_inv_0, Matrix W_inv_k, Matrix S_k, Vec xk, double Nk, Vec m_0, double beta_0, double beta_k, double nu_k) {
                W_inv_0.copyTo(W_inv_k);
                W_inv_k.mutableAdd(Nk, S_k);
                Vec W_inv_k_diag = W_inv_k.getRowView(0);
                W_inv_k_diag.mutableAdd(1.0E-6);
                double meanShiftScale = beta_0 * Nk / beta_k;
                Vec tmp = xk.clone();
                tmp.mutableSubtract(m_0);
                tmp.applyFunction(v -> v * v);
                W_inv_k_diag.mutableAdd(meanShiftScale, tmp);
                W_inv_k_diag.mutableDivide(nu_k + 1.0E-6);
            }

            @Override
            public Matrix allocate(int d) {
                return new DenseMatrix(1, d);
            }

            @Override
            public NormalM asNormal(Vec mean, Matrix cov) {
                return new NormalM(mean, cov.getRowView(0));
            }
        },
        /** Full {@code d x d} covariance matrix. */
        FULL {
            @Override
            public void fit(List<Vec> X, Matrix S_k, double[] contrib, Vec xk, double Nk) {
                int N = contrib.length;
                int d = xk.length();
                S_k.zeroOut();
                DenseVector tmp = new DenseVector(d);
                for (int n = 0; n < N; ++n) {
                    double r_nk = contrib[n];
                    X.get(n).copyTo(tmp);
                    tmp.mutableSubtract(xk);
                    Matrix.OuterProductUpdate(S_k, tmp, tmp, r_nk);
                }
                S_k.mutableMultiply(1.0 / (Nk + 1.0E-6));
            }

            @Override
            public void updateWishart(Matrix W_inv_0, Matrix W_inv_k, Matrix S_k, Vec xk, double Nk, Vec m_0, double beta_0, double beta_k, double nu_k) {
                int d = W_inv_0.rows();
                W_inv_0.copyTo(W_inv_k);
                W_inv_k.mutableAdd(Nk, S_k);
                for (int i = 0; i < d; ++i) {
                    W_inv_k.increment(i, i, 1.0E-6);
                }
                double meanShiftScale = beta_0 * Nk / beta_k;
                Vec tmp = xk.clone();
                tmp.mutableSubtract(m_0);
                Matrix.OuterProductUpdate(W_inv_k, tmp, tmp, meanShiftScale);
                W_inv_k.mutableMultiply(1.0 / nu_k);
            }

            @Override
            public Matrix allocate(int d) {
                return new DenseMatrix(d, d);
            }

            @Override
            public NormalM asNormal(Vec mean, Matrix cov) {
                return new NormalM(mean, cov);
            }
        };

        /**
         * Computes the weighted scatter of {@code X} around {@code xk} into {@code S_k}, divided by
         * {@code Nk + 1e-6}.
         *
         * @param X the data points
         * @param S_k output matrix, as produced by {@link #allocate(int)}
         * @param contrib weight of each point
         * @param xk the center to measure scatter around
         * @param Nk the total weight used as the normalizer
         */
        public abstract void fit(List<Vec> X, Matrix S_k, double[] contrib, Vec xk, double Nk);

        /**
         * Writes into {@code W_inv_k} the value
         * {@code (W_inv_0 + Nk*S_k + 1e-6*I + (beta_0*Nk/beta_k)*(xk-m_0)(xk-m_0)^T) / nu_k},
         * restricted to the diagonal for {@link #DIAG}.
         *
         * @param W_inv_0 prior scale matrix
         * @param W_inv_k output scale matrix
         * @param S_k weighted scatter from {@link #fit}
         * @param xk weighted mean of the component
         * @param Nk total responsibility of the component
         * @param m_0 prior mean
         * @param beta_0 prior mean strength
         * @param beta_k posterior mean strength (the caller passes {@code beta_0 + Nk})
         * @param nu_k posterior degrees of freedom
         */
        public abstract void updateWishart(Matrix W_inv_0, Matrix W_inv_k, Matrix S_k, Vec xk, double Nk, Vec m_0, double beta_0, double beta_k, double nu_k);

        /**
         * Allocates a zeroed matrix in this type's storage layout.
         *
         * @param d dimensionality of the data
         * @return the storage matrix
         */
        public abstract Matrix allocate(int d);

        /**
         * Builds a normal distribution from a mean and a covariance stored in this type's layout.
         *
         * @param mean the mean vector
         * @param cov the covariance in this type's layout
         * @return the corresponding normal distribution
         */
        public abstract NormalM asNormal(Vec mean, Matrix cov);
    }
}

