/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Stack;
import java.util.stream.Collectors;
import jsat.DataSet;
import jsat.clustering.Clusterer;
import jsat.distributions.Distribution;
import jsat.distributions.discrete.Binomial;
import jsat.distributions.multivariate.IndependentDistribution;
import jsat.distributions.multivariate.MultivariateDistribution;
import jsat.distributions.multivariate.NormalM;
import jsat.linear.CholeskyDecomposition;
import jsat.linear.ConstantVector;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.MatrixStatistics;
import jsat.linear.SingularValueDecomposition;
import jsat.linear.Vec;
import jsat.math.OnLineStatistics;
import jsat.math.SpecialMath;
import jsat.utils.IntList;

public class BayesianHAC
implements Clusterer {
    /** Concentration parameter used when computing each merged node's d and pi terms. */
    private double alpha_prior = 1.0;
    /** Observation model used for the leaves; every constructor overwrites this default. */
    private Distributions dist = Distributions.BERNOULLI_BETA;
    /** One fitted distribution per cluster from the last call to cluster(); null before that. */
    protected List<MultivariateDistribution> cluster_dists;

    /**
     * Creates a clusterer that models each cluster with a diagonal Gaussian.
     */
    public BayesianHAC() {
        this(Distributions.GAUSSIAN_DIAG);
    }

    /**
     * Creates a clusterer with the given per-cluster observation model.
     *
     * @param dist the observation model to use for every cluster
     */
    public BayesianHAC(Distributions dist) {
        this.dist = dist;
    }

    /**
     * Copy constructor. The fitted cluster distributions (if any) are deep-copied
     * via their own clone methods.
     *
     * @param toCopy the clusterer to copy
     */
    public BayesianHAC(BayesianHAC toCopy) {
        this(toCopy.dist);
        this.alpha_prior = toCopy.alpha_prior;
        if (toCopy.cluster_dists == null) {
            return;
        }
        List<MultivariateDistribution> copies = new ArrayList<>(toCopy.cluster_dists.size());
        for (MultivariateDistribution fitted : toCopy.cluster_dists) {
            copies.add(fitted.clone());
        }
        this.cluster_dists = copies;
    }

    /**
     * Computes log(exp(log_a) + exp(log_b)) without overflow or underflow.
     *
     * <p>The larger argument is factored out, so the exponential is only ever taken
     * of a non-positive number. Infinite arguments are special-cased: two -infinity
     * inputs (log 0 + log 0) give -infinity rather than NaN, and a +infinity input
     * gives +infinity. NaN inputs propagate as NaN.
     *
     * @param log_a first value, in log space
     * @param log_b second value, in log space
     * @return the log of the sum of the two exponentiated values
     */
    static double log_exp_sum(double log_a, double log_b) {
        if (Double.isNaN(log_a) || Double.isNaN(log_b)) {
            return Double.NaN;
        }
        final double larger = Math.max(log_a, log_b);
        final double smaller = Math.min(log_a, log_b);
        // Both -inf: exp(0)+exp(0) is 0, so the answer is -inf; the general formula would
        // compute (-inf) - (-inf) = NaN. Likewise +inf dominates anything.
        if (Double.isInfinite(larger)) {
            return larger;
        }
        return larger + Math.log1p(Math.exp(smaller - larger));
    }

    /**
     * Greedily merges subtrees, always taking the pair with the highest posterior merge
     * ratio r, and stops once no pair reaches r &gt; 0.5. Each remaining subtree becomes
     * one cluster.
     *
     * @param dataSet      the data to cluster
     * @param parallel     ignored by this implementation
     * @param designations array to fill with cluster ids; allocated when {@code null}
     * @return the cluster id of every data point
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        List<Vec> data = dataSet.getDataVectors();
        if (designations == null) {
            designations = new int[data.size()];
        }
        // Hyper-parameters are computed once, from the whole data set, via the first leaf.
        // Declared as DistPrior (the erasure of Node's HyperParams) so the raw logR call compiles.
        DistPrior priors = null;
        List<Node> current_nodes = new ArrayList<>(data.size());
        for (int i = 0; i < data.size(); ++i) {
            Node n = this.dist.init(i, this.alpha_prior, data);
            if (priors == null) {
                priors = n.computeInitialPrior(data);
            }
            // For a leaf this only initializes its log p(D|T); the returned value is unused.
            n.logR(data, priors);
            current_nodes.add(n);
        }
        final double log_merge_threshold = Math.log(0.5);
        while (current_nodes.size() > 1) {
            double best_r = Double.NEGATIVE_INFINITY;
            int best_i = -1;
            int best_j = -1;
            Node best_merged = null;
            for (int i = 0; i < current_nodes.size(); ++i) {
                Node D_i = current_nodes.get(i);
                for (int j = i + 1; j < current_nodes.size(); ++j) {
                    Node D_j = current_nodes.get(j);
                    Node merged = D_i.merge(D_i, D_j, this.alpha_prior);
                    double log_r = merged.logR(data, priors);
                    if (log_r > best_r) {
                        best_i = i;
                        best_j = j;
                        best_merged = merged;
                        best_r = log_r;
                    }
                }
            }
            // Written as !(x > t) so a NaN ratio also stops merging.
            if (!(best_r > log_merge_threshold)) {
                break;
            }
            // best_j > best_i, so removing j first leaves index i valid.
            current_nodes.remove(best_j);
            current_nodes.remove(best_i);
            current_nodes.add(best_merged);
        }
        this.cluster_dists = new ArrayList<MultivariateDistribution>(current_nodes.size());
        for (int class_id = 0; class_id < current_nodes.size(); ++class_id) {
            Node root = current_nodes.get(class_id);
            List<Integer> owned = root.ownedList();
            for (int pos : owned) {
                designations[pos] = class_id;
            }
            this.cluster_dists.add(root.toDistribution(data));
        }
        return designations;
    }

    /**
     * Returns the per-cluster distributions fitted by the most recent call to
     * {@code cluster}; element i corresponds to cluster id i.
     *
     * @return the fitted distributions, or {@code null} if clustering has not run
     */
    public List<MultivariateDistribution> getClusterDistributions() {
        final List<MultivariateDistribution> fitted = cluster_dists;
        return fitted;
    }

    /**
     * Returns an independent copy of this clusterer, including deep copies of any
     * fitted cluster distributions. (Previously this returned {@code this}, so the
     * "clone" shared and mutated the original's state.)
     *
     * @return a new BayesianHAC equal to this one
     */
    @Override
    public BayesianHAC clone() {
        return new BayesianHAC(this);
    }

    /**
     * Tree node whose clusters are modelled by a full-covariance Gaussian, scored
     * against the hyper-parameters held in {@link WishartFull}.
     * The node keeps two sufficient statistics of the points beneath it.
     */
    protected static class NormalNode
    extends Node<NormalNode, WishartFull> {
        /** Sum over owned points of the outer product x x^T. */
        Matrix XT_X;
        /** Sum over owned points of x (aliases the data vector for a leaf; never mutated). */
        Vec x_sum;

        /** Creates a leaf holding the single data point at index {@code single_point}. */
        public NormalNode(int single_point, double alpha_prior, List<Vec> dataset) {
            super(single_point, alpha_prior);
            final Vec point = dataset.get(single_point);
            final int dim = point.length();
            final Matrix scatter = new DenseMatrix(dim, dim);
            Matrix.OuterProductUpdate(scatter, point, point, 1.0);
            this.XT_X = scatter;
            this.x_sum = point;
        }

        /** Creates the internal node obtained by merging {@code a} and {@code b}. */
        public NormalNode(NormalNode a, NormalNode b, double alpha_prior) {
            super(a, b, alpha_prior);
            this.XT_X = a.XT_X.add(b.XT_X);
            this.x_sum = a.x_sum.add(b.x_sum);
        }

        @Override
        public NormalNode merge(NormalNode a, NormalNode b, double alpha_prior) {
            return new NormalNode(a, b, alpha_prior);
        }

        @Override
        public WishartFull computeInitialPrior(List<Vec> dataset) {
            return new WishartFull(dataset);
        }

        /** Fits a Gaussian to the owned points using their sample mean and covariance. */
        @Override
        public MultivariateDistribution toDistribution(List<Vec> dataset) {
            final List<Integer> members = ownedList();
            final int dim = dataset.get(0).length();
            final DenseVector mu = new DenseVector(dim);
            MatrixStatistics.meanVector(mu, dataset, members);
            final DenseMatrix sigma = new DenseMatrix(dim, dim);
            MatrixStatistics.covarianceMatrix(mu, sigma, dataset, members);
            return new NormalM((Vec) mu, sigma);
        }

        /**
         * Log marginal likelihood of the owned points under the single-cluster hypothesis.
         * The posterior scale matrix is
         * S' = S + sum(x x^T) + rN/(N+r) m m^T - 1/(N+r) (sum x)(sum x)^T
         *      - r/(N+r) (m (sum x)^T + (sum x) m^T).
         */
        @Override
        public double log_null(List<Vec> dataset, WishartFull priors) {
            final int N = this.size;
            final int k = priors.m.length();
            final double r = priors.r;
            final double v = priors.v;
            final double v_p = priors.v + (double) N;
            final double nPlusR = (double) N + r;

            final Matrix sPost = priors.S.add(this.XT_X);
            Matrix.OuterProductUpdate(sPost, priors.m, priors.m, r * (double) N / nPlusR);
            Matrix.OuterProductUpdate(sPost, this.x_sum, this.x_sum, -1.0 / nPlusR);
            Matrix.OuterProductUpdate(sPost, priors.m, this.x_sum, -r / nPlusR);
            Matrix.OuterProductUpdate(sPost, this.x_sum, priors.m, -r / nPlusR);
            final double logDetPost = new CholeskyDecomposition(sPost).getLogDet();

            double logLik = priors.log_shared_term + -v_p / 2.0 * logDetPost;
            for (int d = 1; d <= k; ++d) {
                logLik += SpecialMath.lnGamma((v_p + 1.0 - (double) d) / 2.0)
                        - SpecialMath.lnGamma((v + 1.0 - (double) d) / 2.0);
            }
            final double log2 = Math.log(2.0);
            logLik += v_p * (double) k / 2.0 * log2 - v * (double) k / 2.0 * log2;
            logLik += (double) (-N * k) / 2.0 * Math.log(Math.PI * 2);
            logLik += (double) k / 2.0 * (Math.log(r) - Math.log(nPlusR));
            return logLik;
        }
    }

    /**
     * Tree node whose clusters are modelled by a Gaussian with diagonal covariance,
     * scored against the hyper-parameters held in {@link WishartDiag}. All second-order
     * statistics are kept as vectors of per-dimension values.
     */
    protected static class NormalDiagNode
    extends Node<NormalDiagNode, WishartDiag> {
        /** Per-dimension sum over owned points of x_d^2. */
        Vec XT_X;
        /** Sum over owned points of x (aliases the data vector for a leaf; never mutated). */
        Vec x_sum;

        /** Creates a leaf holding the single data point at index {@code single_point}. */
        public NormalDiagNode(int single_point, double alpha_prior, List<Vec> dataset) {
            super(single_point, alpha_prior);
            final Vec point = dataset.get(single_point);
            this.x_sum = point;
            this.XT_X = point.pairwiseMultiply(point);
        }

        /** Creates the internal node obtained by merging {@code a} and {@code b}. */
        public NormalDiagNode(NormalDiagNode a, NormalDiagNode b, double alpha_prior) {
            super(a, b, alpha_prior);
            this.XT_X = a.XT_X.add(b.XT_X);
            this.x_sum = a.x_sum.add(b.x_sum);
        }

        @Override
        public NormalDiagNode merge(NormalDiagNode a, NormalDiagNode b, double alpha_prior) {
            return new NormalDiagNode(a, b, alpha_prior);
        }

        @Override
        public WishartDiag computeInitialPrior(List<Vec> dataset) {
            return new WishartDiag(dataset);
        }

        /** Fits a diagonal Gaussian to the owned points using their sample mean and variances. */
        @Override
        public MultivariateDistribution toDistribution(List<Vec> dataset) {
            final List<Integer> members = ownedList();
            final int dim = dataset.get(0).length();
            final DenseVector mu = new DenseVector(dim);
            MatrixStatistics.meanVector(mu, dataset, members);
            final DenseVector variances = new DenseVector(dim);
            MatrixStatistics.covarianceDiag(mu, variances, dataset, members);
            return new NormalM((Vec) mu, variances);
        }

        /**
         * Log marginal likelihood of the owned points under the single-cluster hypothesis,
         * using only the diagonal of the posterior scale matrix S'.
         */
        @Override
        public double log_null(List<Vec> dataset, WishartDiag priors) {
            final int N = this.size;
            final int k = priors.m.length();
            final double r = priors.r;
            final double v = priors.v;
            final double v_p = priors.v + (double) N;
            final double nPlusR = (double) N + r;

            // diag(S') = S + sum x^2 + rN/(N+r) m^2 - (sum x)^2/(N+r) - 2r/(N+r) m (sum x)
            final Vec sPost = priors.S.add(this.XT_X);
            sPost.mutableAdd(r * (double) N / nPlusR, priors.m.pairwiseMultiply(priors.m));
            sPost.mutableAdd(-1.0 / nPlusR, this.x_sum.pairwiseMultiply(this.x_sum));
            sPost.mutableAdd(-r / nPlusR, priors.m.pairwiseMultiply(this.x_sum).multiply(2.0));

            double logDetPost = 0.0;
            for (int d = 0; d < sPost.length(); ++d) {
                logDetPost += Math.log(sPost.get(d));
            }

            double logLik = priors.log_shared_term + -v_p / 2.0 * logDetPost;
            for (int d = 1; d <= k; ++d) {
                logLik += SpecialMath.lnGamma((v_p + 1.0 - (double) d) / 2.0)
                        - SpecialMath.lnGamma((v + 1.0 - (double) d) / 2.0);
            }
            final double log2 = Math.log(2.0);
            logLik += v_p * (double) k / 2.0 * log2 - v * (double) k / 2.0 * log2;
            logLik += (double) (-N * k) / 2.0 * Math.log(Math.PI * 2);
            logLik += (double) k / 2.0 * (Math.log(r) - Math.log(nPlusR));
            return logLik;
        }
    }

    /**
     * Tree node whose clusters are modelled as independent Bernoulli features with
     * Beta priors ({@link BetaConjugate}).
     */
    protected static class BernoulliBetaNode
    extends Node<BernoulliBetaNode, BetaConjugate> {
        /**
         * Per-feature sum over owned points (a count of ones, assuming 0/1 data).
         * Aliases the data vector for a leaf; never mutated.
         */
        public Vec m;

        /** Creates a leaf holding the single data point at index {@code single_point}. */
        public BernoulliBetaNode(int single_point, double alpha_prior, List<Vec> dataset) {
            super(single_point, alpha_prior);
            m = dataset.get(single_point);
        }

        /** Creates the internal node obtained by merging {@code a} and {@code b}. */
        public BernoulliBetaNode(BernoulliBetaNode a, BernoulliBetaNode b, double alpha_prior) {
            super(a, b, alpha_prior);
            m = a.m.add(b.m);
        }

        @Override
        public BetaConjugate computeInitialPrior(List<Vec> dataset) {
            return new BetaConjugate(dataset);
        }

        /**
         * Log marginal likelihood of the owned points: per feature,
         * ln[ G(a+b) G(a+m) G(b+N-m) / (G(a) G(b) G(a+b+N)) ], summed over features.
         */
        @Override
        public double log_null(List<Vec> dataset, BetaConjugate priors) {
            final int dims = dataset.get(0).length();
            final int N = this.size;
            double total = 0.0;
            for (int d = 0; d < dims; ++d) {
                final double a = priors.alpha_prior.get(d);
                final double b = priors.beta_prior.get(d);
                final double ones = this.m.get(d);
                final double numer = SpecialMath.lnGamma(a + b)
                        + SpecialMath.lnGamma(a + ones)
                        + SpecialMath.lnGamma(b + (double) N - ones);
                final double denom = SpecialMath.lnGamma(a)
                        + SpecialMath.lnGamma(b)
                        + SpecialMath.lnGamma(a + b + (double) N);
                total += numer - denom;
            }
            return total;
        }

        @Override
        public BernoulliBetaNode merge(BernoulliBetaNode a, BernoulliBetaNode b, double alpha_prior) {
            return new BernoulliBetaNode(a, b, alpha_prior);
        }

        /** Builds independent Binomial(1, p_d) marginals with p_d = m_d / size. */
        @Override
        public MultivariateDistribution toDistribution(List<Vec> dataset) {
            final double count = this.size;
            final int dims = this.m.length();
            final List<Distribution> marginals = new ArrayList<>(dims);
            for (int d = 0; d < dims; ++d) {
                marginals.add(new Binomial(1, this.m.get(d) / count));
            }
            return new IndependentDistribution(marginals);
        }
    }

    /**
     * Hyper-parameters for {@link NormalNode}, derived from the whole data set:
     * degrees of freedom v = dimension, mean scale r = 0.001, prior mean m = data mean,
     * and scale matrix S built from the data covariance.
     */
    protected static class WishartFull
    implements DistPrior {
        double v;
        double r;
        Vec m;
        Matrix S;
        /** Cached v/2 * log|S|, shared by every log_null evaluation. */
        double log_shared_term;

        public WishartFull(List<Vec> dataset) {
            final int dim = dataset.get(0).length();
            this.v = dim;
            this.r = 0.001;
            this.m = new DenseVector(dim);
            MatrixStatistics.meanVector(this.m, dataset);
            this.S = new DenseMatrix(dim, dim);
            MatrixStatistics.covarianceMatrix(this.m, this.S, dataset);

            // The decompositions get a copy so S itself is not disturbed.
            final boolean fullRank = new SingularValueDecomposition(this.S.clone()).isFullRank();
            if (fullRank) {
                this.S.mutableMultiply(0.05);
            } else {
                // Rank-deficient covariance: inflate the diagonal using the variance of all
                // coordinate values pooled together, so the Cholesky below can succeed.
                final OnLineStatistics pooled = new OnLineStatistics();
                for (Vec x : dataset) {
                    for (int i = 0; i < x.length(); ++i) {
                        pooled.add(x.get(i));
                    }
                }
                for (int i = 0; i < this.S.rows(); ++i) {
                    this.S.increment(i, i, 0.1 * this.S.get(i, i) + pooled.getVarance());
                }
            }

            final double logDetS = new CholeskyDecomposition(this.S.clone()).getLogDet();
            this.log_shared_term = 0.0;
            this.log_shared_term += this.v / 2.0 * logDetS;
        }
    }

    /**
     * Hyper-parameters for {@link NormalDiagNode}, derived from the whole data set:
     * degrees of freedom v = dimension, mean scale r = 0.001, prior mean m = data mean,
     * and a diagonal scale S built from the per-feature variances.
     *
     * <p>When some feature has non-positive (e.g. zero, for a constant feature) or NaN
     * variance, log|S| would be -infinity and every log_null would become NaN. In that
     * case, mirroring {@link WishartFull}'s rank-deficient fallback, each diagonal entry
     * is inflated by 10% of itself plus the variance of all coordinate values pooled
     * together. NOTE(review): if the pooled variance is also 0 (all data identical) the
     * problem remains.
     */
    protected static class WishartDiag
    implements DistPrior {
        double v;
        double r;
        Vec m;
        Vec S;
        /** Cached v/2 * log|S|, shared by every log_null evaluation. */
        double log_shared_term;

        public WishartDiag(List<Vec> dataset) {
            int k = dataset.get(0).length();
            this.v = k;
            this.r = 0.001;
            this.m = new DenseVector(k);
            MatrixStatistics.meanVector(this.m, dataset);
            this.S = new DenseVector(k);
            MatrixStatistics.covarianceDiag(this.m, this.S, dataset);
            if (allPositive(this.S)) {
                this.S.mutableDivide(20.0);
            } else {
                OnLineStatistics pooled = new OnLineStatistics();
                for (Vec x : dataset) {
                    for (int i = 0; i < x.length(); ++i) {
                        pooled.add(x.get(i));
                    }
                }
                double pooledVar = pooled.getVarance();
                for (int i = 0; i < k; ++i) {
                    double s_i = this.S.get(i);
                    this.S.set(i, s_i + (0.1 * s_i + pooledVar));
                }
            }
            this.log_shared_term = 0.0;
            double log_det_S = 0.0;
            for (int i = 0; i < k; ++i) {
                log_det_S += Math.log(this.S.get(i));
            }
            this.log_shared_term += this.v / 2.0 * log_det_S;
        }

        /** Returns true when every entry is strictly positive; NaN entries count as not positive. */
        private static boolean allPositive(Vec vec) {
            for (int i = 0; i < vec.length(); ++i) {
                if (!(vec.get(i) > 0.0)) {
                    return false;
                }
            }
            return true;
        }
    }

    /**
     * Beta hyper-parameters for {@link BernoulliBetaNode}, one (alpha, beta) pair per
     * feature: alpha = 2*mean + 0.001 and beta = 2*(1 - mean) + 0.001, where mean is the
     * per-feature data mean.
     */
    protected static class BetaConjugate
    implements DistPrior {
        public Vec alpha_prior;
        public Vec beta_prior;

        public BetaConjugate(List<Vec> dataset) {
            final int dims = dataset.get(0).length();
            final Vec featureMean = MatrixStatistics.meanVector(dataset);
            final Vec scaledMean = featureMean.multiply(2.0);
            alpha_prior = scaledMean.add(0.001);
            final Vec complement = new DenseVector(new ConstantVector(1.0, dims)).subtract(featureMean);
            beta_prior = complement.multiply(2.0).add(0.001);
        }
    }

    /**
     * A subtree of the Bayesian hierarchical clustering. A leaf holds one data point;
     * an internal node is the merge of two children. Subclasses supply the observation
     * model through {@link #log_null} and the sufficient statistics they keep.
     *
     * @param <NodeType>    the concrete node subclass produced by {@link #merge}
     * @param <HyperParams> the prior hyper-parameter type consumed by {@link #log_null}
     */
    protected static abstract class Node<NodeType extends Node, HyperParams extends DistPrior> {
        /** Data-point index held by a leaf; -1 for an internal node. */
        int owned;
        /** Indices of every data point beneath this node (sorted for internal nodes). */
        IntList allChilds;
        /** log d_k: recursive normalizer, log(alpha*Gamma(n) + d_left*d_right) for internal nodes. */
        double log_d;
        /** log pi_k: weight of the single-cluster hypothesis in logR (0 for a leaf). */
        double log_pi;
        /** log p(D_k | T_k); NaN until logR has been called on this node. */
        double log_pdt;
        NodeType left_child;
        NodeType right_child;
        /** Number of data points beneath this node. */
        int size;

        /** Creates a leaf for the data point at index {@code single_point}. */
        public Node(int single_point, double alpha_prior) {
            this.owned = single_point;
            this.allChilds = IntList.view(new int[]{single_point});
            // Not yet computed; logR fills this in. NaN (as for merged nodes) makes
            // accidental use before logR visible instead of silently using a bogus 1.0.
            this.log_pdt = Double.NaN;
            this.size = 1;
            this.log_d = Math.log(alpha_prior);
            this.log_pi = 0.0; // log(1)
        }

        /** Creates the internal node whose children are {@code a} and {@code b}. */
        public Node(NodeType a, NodeType b, double alpha_prior) {
            this.owned = -1;
            this.log_pdt = Double.NaN;
            this.size = ((Node) a).size + ((Node) b).size;
            this.allChilds = new IntList(((Node) a).allChilds);
            this.allChilds.addAll(((Node) b).allChilds);
            Collections.sort(this.allChilds);
            // d_k = alpha*Gamma(n_k) + d_left*d_right and pi_k = alpha*Gamma(n_k)/d_k, in log space.
            double tmp = Math.log(alpha_prior) + SpecialMath.lnGamma(this.size);
            this.log_d = BayesianHAC.log_exp_sum(tmp, ((Node) a).log_d + ((Node) b).log_d);
            this.log_pi = tmp - this.log_d;
            this.left_child = a;
            this.right_child = b;
        }

        /**
         * Computes and stores log p(D|T) for this node and returns the log merge ratio
         * log r = log(pi * p(D|H1)) - log p(D|T).
         *
         * @return log r; 0.0 (log 1) for a leaf, which always forms its own cluster
         */
        public double logR(List<Vec> dataset, HyperParams priors) {
            if (this.size == 1) {
                this.log_pdt = this.log_null(dataset, priors);
                return 0.0;
            }
            double log_numer = this.log_pi + this.log_null(dataset, priors);
            // log(1 - pi) computed stably from log pi.
            double log_neg_pi = Math.log(-Math.expm1(this.log_pi));
            double log_rhs = log_neg_pi + ((Node) this.left_child).log_pdt + ((Node) this.right_child).log_pdt;
            this.log_pdt = BayesianHAC.log_exp_sum(log_numer, log_rhs);
            return log_numer - this.log_pdt;
        }

        /** Returns a new node whose children are {@code a} and {@code b}. */
        public abstract NodeType merge(NodeType a, NodeType b, double alpha_prior);

        /** Builds the prior hyper-parameters from the whole data set. */
        public abstract HyperParams computeInitialPrior(List<Vec> dataset);

        /** Fits a distribution to the data points owned by this node. */
        public abstract MultivariateDistribution toDistribution(List<Vec> dataset);

        public boolean isLeaf() {
            return this.right_child == null && this.left_child == null;
        }

        /** Log marginal likelihood of this node's points under the single-cluster hypothesis. */
        public abstract double log_null(List<Vec> dataset, HyperParams priors);

        /**
         * Iterates over the data-point indices of the leaves beneath this node
         * (depth-first, right subtree before left).
         */
        public Iterator<Integer> indxIter() {
            final Stack<Node> remains = new Stack<Node>();
            remains.push(this);
            return new Iterator<Integer>() {

                @Override
                public boolean hasNext() {
                    // Expand internal nodes until a leaf is on top (or nothing is left).
                    while (!remains.isEmpty() && !remains.peek().isLeaf()) {
                        Node c = remains.pop();
                        remains.push(c.left_child);
                        remains.push(c.right_child);
                    }
                    return !remains.isEmpty();
                }

                @Override
                public Integer next() {
                    // hasNext() guarantees a leaf is on top, so next() is correct even
                    // when the caller did not call hasNext() first.
                    if (!hasNext()) {
                        throw new NoSuchElementException();
                    }
                    return remains.pop().owned;
                }
            };
        }

        /** Returns the data-point indices beneath this node, in {@link #indxIter} order. */
        public List<Integer> ownedList() {
            IntList a = new IntList(this.size);
            Iterator<Integer> iter = this.indxIter();
            while (iter.hasNext()) {
                a.add(iter.next());
            }
            return a;
        }
    }

    protected static interface DistPrior {
    }

    /**
     * The per-cluster observation models available to {@link BayesianHAC}.
     * Each constant knows how to build a leaf node for a single data point.
     */
    public static enum Distributions {
        /** Independent Bernoulli features with Beta priors. */
        BERNOULLI_BETA,
        /** Gaussian with diagonal covariance. */
        GAUSSIAN_DIAG,
        /** Gaussian with full covariance. */
        GAUSSIAN_FULL;

        /**
         * Creates the leaf node for data point {@code point}.
         *
         * @param point       index of the point within {@code data}
         * @param alpha_prior concentration parameter of the clustering
         * @param data        the full data set
         * @return a new leaf node of this model's type
         */
        public Node init(int point, double alpha_prior, List<Vec> data) {
            switch (this) {
                case BERNOULLI_BETA:
                    return new BernoulliBetaNode(point, alpha_prior, data);
                case GAUSSIAN_DIAG:
                    return new NormalDiagNode(point, alpha_prior, data);
                case GAUSSIAN_FULL:
                    return new NormalNode(point, alpha_prior, data);
                default:
                    throw new AssertionError(this);
            }
        }
    }
}

