/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.distributions.ChiSquared;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.LUPDecomposition;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.MahalanobisDistance;
import jsat.utils.IndexTable;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.Tuple3;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

public class MatrixStatistics {
    /**
     * Hidden constructor: every member of this class is static, so no
     * instance is ever needed.
     */
    private MatrixStatistics() {
        // intentionally empty
    }

    /**
     * Allocates a new {@link DenseVector} sized like the first element of
     * {@code dataSet}, accumulates the mean into it via
     * {@link #meanVector(Vec, List)}, and returns it.
     *
     * @param <V> the vector type stored in the list
     * @param dataSet the vectors to average
     * @return the newly allocated vector holding the result
     * @throws ArithmeticException if {@code dataSet} is empty
     */
    public static <V extends Vec> Vec meanVector(List<V> dataSet) {
        if (dataSet.isEmpty()) {
            throw new ArithmeticException("Can not compute the mean of zero data points");
        }
        final int dimensions = dataSet.get(0).length();
        Vec result = new DenseVector(dimensions);
        meanVector(result, dataSet);
        return result;
    }

    /**
     * Allocates a new {@link DenseVector} with one entry per numerical
     * variable of {@code dataSet} and fills it through
     * {@link #meanVector(Vec, DataSet)}.
     *
     * @param dataSet the data set whose numerical values are averaged
     * @return the newly allocated vector holding the result
     * @throws ArithmeticException if {@code dataSet} has size zero
     */
    public static Vec meanVector(DataSet dataSet) {
        Vec result = new DenseVector(dataSet.getNumNumericalVars());
        meanVector(result, dataSet);
        return result;
    }

    /**
     * Adds every vector of {@code dataSet} onto {@code mean} and then divides
     * {@code mean} by the number of vectors.
     * <p>
     * {@code mean} is not cleared first: whatever it holds on entry takes
     * part in the sum, so pass a zeroed vector to obtain the plain mean.
     *
     * @param <V> the vector type stored in the list
     * @param mean the accumulator, modified in place
     * @param dataSet the vectors to average
     * @throws ArithmeticException if {@code dataSet} is empty, or if its
     * first vector's length differs from the length of {@code mean}
     */
    public static <V extends Vec> void meanVector(Vec mean, List<V> dataSet) {
        if (dataSet.isEmpty()) {
            throw new ArithmeticException("Can not compute the mean of zero data points");
        }
        final boolean sameLength = dataSet.get(0).length() == mean.length();
        if (!sameLength) {
            throw new ArithmeticException("Vector dimensions do not agree");
        }
        dataSet.forEach(point -> mean.mutableAdd(point));
        mean.mutableDivide(dataSet.size());
    }

    /**
     * Adds the vectors of {@code dataSet} selected by {@code subset} onto
     * {@code mean} and then divides {@code mean} by {@code subset.size()}.
     * <p>
     * {@code mean} is not cleared first, so pass a zeroed vector to obtain
     * the plain mean of the selection.
     *
     * @param <V> the vector type stored in the list
     * @param mean the accumulator, modified in place
     * @param dataSet the list the indices refer to
     * @param subset the indices into {@code dataSet} to include
     * @throws ArithmeticException if {@code dataSet} is empty, or if its
     * first vector's length differs from the length of {@code mean}
     */
    public static <V extends Vec> void meanVector(Vec mean, List<V> dataSet, Collection<Integer> subset) {
        if (dataSet.isEmpty()) {
            throw new ArithmeticException("Can not compute the mean of zero data points");
        }
        final boolean sameLength = dataSet.get(0).length() == mean.length();
        if (!sameLength) {
            throw new ArithmeticException("Vector dimensions do not agree");
        }
        subset.forEach(index -> mean.mutableAdd(dataSet.get(index)));
        mean.mutableDivide(subset.size());
    }

    /**
     * Adds the numerical values of every data point, scaled by that point's
     * weight, onto {@code mean} and then divides {@code mean} by the sum of
     * all weights.
     * <p>
     * {@code mean} is not cleared first, so pass a zeroed vector to obtain
     * the plain weighted mean.
     *
     * @param mean the accumulator, modified in place
     * @param dataSet the data set to average
     * @throws ArithmeticException if {@code dataSet} has size zero
     */
    public static void meanVector(Vec mean, DataSet dataSet) {
        final int count = dataSet.size();
        if (count == 0) {
            throw new ArithmeticException("Can not compute the mean of zero data points");
        }
        double totalWeight = 0.0;
        int row = 0;
        while (row < count) {
            final DataPoint point = dataSet.getDataPoint(row);
            final double weight = dataSet.getWeight(row);
            totalWeight += weight;
            mean.mutableAdd(weight, point.getNumericalValues());
            row++;
        }
        mean.mutableDivide(totalWeight);
    }

    /**
     * Allocates a square {@link DenseMatrix} sized by the length of
     * {@code mean} and fills it through
     * {@link #covarianceMatrix(Vec, Matrix, List)}.
     *
     * @param <V> the vector type stored in the list
     * @param mean the mean the data is centered on
     * @param dataSet the vectors to compute the covariance of
     * @return the newly allocated matrix holding the result
     * @throws ArithmeticException if the delegate rejects the arguments
     */
    public static <V extends Vec> Matrix covarianceMatrix(Vec mean, List<V> dataSet) {
        final int dim = mean.length();
        Matrix result = new DenseMatrix(dim, dim);
        covarianceMatrix(mean, result, dataSet);
        return result;
    }

    /**
     * For every vector in {@code dataSet}, subtracts {@code mean} from a copy
     * and applies {@code Matrix.OuterProductUpdate} of that centered copy with
     * itself (scale 1.0) to {@code covariance}; finally multiplies
     * {@code covariance} by {@code 1 / (n - 1)} where {@code n} is the list
     * size.
     * <p>
     * {@code covariance} is not cleared first, so pass a zeroed matrix to
     * obtain the plain sample covariance.
     *
     * @param <V> the vector type stored in the list
     * @param mean the mean the data is centered on
     * @param covariance the square accumulator, modified in place
     * @param dataSet the vectors to compute the covariance of
     * @throws ArithmeticException if {@code covariance} is not square, its
     * row count differs from the length of {@code mean}, {@code dataSet} is
     * empty, or the first vector's length differs from that of {@code mean}
     */
    public static <V extends Vec> void covarianceMatrix(Vec mean, Matrix covariance, List<V> dataSet) {
        if (!covariance.isSquare()) {
            throw new ArithmeticException("Storage for covariance matrix must be square");
        }
        final int dim = mean.length();
        if (covariance.rows() != dim) {
            throw new ArithmeticException("Covariance Matrix size and mean size do not agree");
        }
        if (dataSet.isEmpty()) {
            throw new ArithmeticException("No data points to compute covariance from");
        }
        if (dim != dataSet.get(0).length()) {
            throw new ArithmeticException("Data vectors do not agree with mean and covariance matrix");
        }
        // One reusable buffer holds (x - mean) for the current vector.
        final DenseVector centered = new DenseVector(dim);
        dataSet.forEach(point -> {
            point.copyTo(centered);
            centered.mutableSubtract(mean);
            Matrix.OuterProductUpdate(covariance, centered, centered, 1.0);
        });
        final double denominator = dataSet.size() - 1.0;
        covariance.mutableMultiply(1.0 / denominator);
    }

    /**
     * Same accumulation as {@link #covarianceMatrix(Vec, Matrix, List)} but
     * restricted to the vectors of {@code dataSet} whose indices appear in
     * {@code subset}; the final scale factor is
     * {@code 1 / (subset.size() - 1)}.
     * <p>
     * {@code covariance} is not cleared first, so pass a zeroed matrix to
     * obtain the plain sample covariance of the selection.
     *
     * @param <V> the vector type stored in the list
     * @param mean the mean the data is centered on
     * @param covariance the square accumulator, modified in place
     * @param dataSet the list the indices refer to
     * @param subset the indices into {@code dataSet} to include
     * @throws ArithmeticException if {@code covariance} is not square, its
     * row count differs from the length of {@code mean}, {@code dataSet} is
     * empty, or the first vector's length differs from that of {@code mean}
     */
    public static <V extends Vec> void covarianceMatrix(Vec mean, Matrix covariance, List<V> dataSet, Collection<Integer> subset) {
        if (!covariance.isSquare()) {
            throw new ArithmeticException("Storage for covariance matrix must be square");
        }
        final int dim = mean.length();
        if (covariance.rows() != dim) {
            throw new ArithmeticException("Covariance Matrix size and mean size do not agree");
        }
        if (dataSet.isEmpty()) {
            throw new ArithmeticException("No data points to compute covariance from");
        }
        if (dim != dataSet.get(0).length()) {
            throw new ArithmeticException("Data vectors do not agree with mean and covariance matrix");
        }
        // One reusable buffer holds (x - mean) for the current vector.
        final DenseVector centered = new DenseVector(dim);
        subset.forEach(index -> {
            dataSet.get(index).copyTo(centered);
            centered.mutableSubtract(mean);
            Matrix.OuterProductUpdate(covariance, centered, centered, 1.0);
        });
        final double denominator = subset.size() - 1.0;
        covariance.mutableMultiply(1.0 / denominator);
    }

    /**
     * Sums the weights and the squared weights of {@code dataSet} and hands
     * them to
     * {@link #covarianceMatrix(Vec, DataSet, Matrix, double, double)}.
     *
     * @param mean the mean the data is centered on
     * @param dataSet the weighted data set
     * @param covariance the square accumulator, modified in place
     * @throws ArithmeticException if the delegate rejects the arguments
     */
    public static void covarianceMatrix(Vec mean, DataSet dataSet, Matrix covariance) {
        double weightSum = 0.0;
        double squaredWeightSum = 0.0;
        final int count = dataSet.size();
        for (int row = 0; row < count; row++) {
            final double weight = dataSet.getWeight(row);
            weightSum += weight;
            squaredWeightSum += Math.pow(weight, 2.0);
        }
        covarianceMatrix(mean, dataSet, covariance, weightSum, squaredWeightSum);
    }

    /**
     * For every data point, subtracts {@code mean} from a copy of its
     * numerical values and applies {@code Matrix.OuterProductUpdate} of that
     * centered copy with itself, scaled by the point's weight, to
     * {@code covariance}. Finally multiplies {@code covariance} by
     * {@code sumOfWeights / (sumOfWeights^2 - sumOfSquaredWeights)}.
     * <p>
     * {@code covariance} is not cleared first, so pass a zeroed matrix to
     * obtain the plain weighted covariance.
     *
     * @param mean the mean the data is centered on
     * @param dataSet the weighted data set
     * @param covariance the square accumulator, modified in place
     * @param sumOfWeights the sum of all data point weights
     * @param sumOfSquaredWeights the sum of all squared data point weights
     * @throws ArithmeticException if {@code covariance} is not square, its
     * row count differs from the length of {@code mean}, {@code dataSet} is
     * empty, or the number of numerical variables differs from the length of
     * {@code mean}
     */
    public static void covarianceMatrix(Vec mean, DataSet dataSet, Matrix covariance, double sumOfWeights, double sumOfSquaredWeights) {
        if (!covariance.isSquare()) {
            throw new ArithmeticException("Storage for covariance matrix must be square");
        }
        final int dim = mean.length();
        if (covariance.rows() != dim) {
            throw new ArithmeticException("Covariance Matrix size and mean size do not agree");
        }
        if (dataSet.isEmpty()) {
            throw new ArithmeticException("No data points to compute covariance from");
        }
        if (dim != dataSet.getNumNumericalVars()) {
            throw new ArithmeticException("Data vectors do not agree with mean and covariance matrix");
        }
        // One reusable buffer holds (x - mean) for the current data point.
        final DenseVector centered = new DenseVector(dim);
        int row = 0;
        while (row < dataSet.size()) {
            final Vec values = dataSet.getDataPoint(row).getNumericalValues();
            values.copyTo(centered);
            centered.mutableSubtract(mean);
            Matrix.OuterProductUpdate(covariance, centered, centered, dataSet.getWeight(row));
            row++;
        }
        final double denominator = Math.pow(sumOfWeights, 2.0) - sumOfSquaredWeights;
        covariance.mutableMultiply(sumOfWeights / denominator);
    }

    /**
     * Allocates a square {@link DenseMatrix} sized by the length of
     * {@code mean} and fills it through
     * {@link #covarianceMatrix(Vec, DataSet, Matrix)}.
     *
     * @param mean the mean the data is centered on
     * @param dataSet the weighted data set
     * @return the newly allocated matrix holding the result
     * @throws ArithmeticException if the delegate rejects the arguments
     */
    public static Matrix covarianceMatrix(Vec mean, DataSet dataSet) {
        final int dim = mean.length();
        Matrix result = new DenseMatrix(dim, dim);
        covarianceMatrix(mean, dataSet, result);
        return result;
    }

    /**
     * Accumulates the weighted per-dimension variance of {@code dataset}
     * around {@code means} into {@code diag}.
     * <p>
     * Only the entries that each vector's iterator visits are read directly.
     * Entries that are not visited are treated as zeros and contribute
     * {@code means[j]^2} each. Both kinds of contribution are weighted by the
     * owning data point's weight, and the total is divided by the sum of all
     * weights. With unit weights this equals counting the unvisited entries.
     * <p>
     * {@code diag} is not cleared first, so pass a zeroed vector to obtain
     * the plain result.
     *
     * @param means the per-dimension means to center on
     * @param diag the accumulator, modified in place
     * @param dataset the weighted data set
     */
    public static void covarianceDiag(Vec means, Vec diag, DataSet dataset) {
        final int n = dataset.size();
        final int d = dataset.getNumNumericalVars();
        // Per dimension: total weight of the data points whose entry was visited.
        final double[] visitedWeight = new double[d];
        double sumOfWeights = 0.0;
        for (int i = 0; i < n; ++i) {
            final double w = dataset.getWeight(i);
            sumOfWeights += w;
            final Vec x = dataset.getDataPoint(i).getNumericalValues();
            for (IndexValue iv : x) {
                final int indx = iv.getIndex();
                visitedWeight[indx] += w;
                diag.increment(indx, w * Math.pow(iv.getValue() - means.get(indx), 2.0));
            }
        }
        for (int j = 0; j < d; ++j) {
            // Weight carried by the points whose entry j was not visited (a zero).
            // Clamped so rounding in the subtraction can never go negative.
            final double zeroWeight = Math.max(0.0, sumOfWeights - visitedWeight[j]);
            diag.increment(j, Math.pow(means.get(j), 2.0) * zeroWeight);
        }
        diag.mutableDivide(sumOfWeights);
    }

    /**
     * Allocates a new {@link DenseVector} with one entry per numerical
     * variable of {@code dataset} and fills it through
     * {@link #covarianceDiag(Vec, Vec, DataSet)}.
     *
     * @param means the per-dimension means to center on
     * @param dataset the weighted data set
     * @return the newly allocated vector holding the result
     */
    public static Vec covarianceDiag(Vec means, DataSet dataset) {
        Vec result = new DenseVector(dataset.getNumNumericalVars());
        covarianceDiag(means, result, dataset);
        return result;
    }

    /**
     * Accumulates the per-dimension variance of {@code dataset} around
     * {@code means} into {@code diag}.
     * <p>
     * Entries visited by a vector's iterator contribute their squared
     * deviation from the mean. For each dimension, every vector that did not
     * visit it contributes {@code means[j]^2}, i.e. the entry is treated as a
     * zero. The total is divided by the number of vectors.
     * <p>
     * {@code diag} is not cleared first, so pass a zeroed vector to obtain
     * the plain result.
     *
     * @param <V> the vector type stored in the list
     * @param means the per-dimension means to center on
     * @param diag the accumulator, modified in place
     * @param dataset the vectors to process; its first element fixes the
     * dimension
     */
    public static <V extends Vec> void covarianceDiag(Vec means, Vec diag, List<V> dataset) {
        final int total = dataset.size();
        final int[] visited = new int[dataset.get(0).length()];
        for (V point : dataset) {
            for (IndexValue entry : point) {
                final int col = entry.getIndex();
                visited[col]++;
                final double delta = entry.getValue() - means.get(col);
                diag.increment(col, Math.pow(delta, 2.0));
            }
        }
        // Account for the entries no iterator reported (treated as zeros).
        for (int col = 0; col < visited.length; col++) {
            final int zeros = total - visited[col];
            diag.increment(col, Math.pow(means.get(col), 2.0) * zeros);
        }
        diag.mutableDivide(total);
    }

    /**
     * Same accumulation as {@link #covarianceDiag(Vec, Vec, List)} but
     * restricted to the vectors of {@code dataset} whose indices appear in
     * {@code subset}; counts and the final divisor use {@code subset.size()}.
     * <p>
     * {@code diag} is not cleared first, so pass a zeroed vector to obtain
     * the plain result.
     *
     * @param <V> the vector type stored in the list
     * @param means the per-dimension means to center on
     * @param diag the accumulator, modified in place
     * @param dataset the list the indices refer to; its first element fixes
     * the dimension
     * @param subset the indices into {@code dataset} to include
     */
    public static <V extends Vec> void covarianceDiag(Vec means, Vec diag, List<V> dataset, Collection<Integer> subset) {
        final int total = subset.size();
        final int[] visited = new int[dataset.get(0).length()];
        for (Integer index : subset) {
            final V point = dataset.get(index);
            for (IndexValue entry : point) {
                final int col = entry.getIndex();
                visited[col]++;
                final double delta = entry.getValue() - means.get(col);
                diag.increment(col, Math.pow(delta, 2.0));
            }
        }
        // Account for the entries no iterator reported (treated as zeros).
        for (int col = 0; col < visited.length; col++) {
            final int zeros = total - visited[col];
            diag.increment(col, Math.pow(means.get(col), 2.0) * zeros);
        }
        diag.mutableDivide(total);
    }

    /**
     * Allocates a new {@link DenseVector} sized like the first element of
     * {@code dataset} and fills it with the per-dimension variance of the
     * vectors selected by {@code subset}, via
     * {@link #covarianceDiag(Vec, Vec, List, Collection)}.
     * <p>
     * Earlier this method ignored {@code subset} and processed the whole
     * list; the selection is now honored.
     *
     * @param <V> the vector type stored in the list
     * @param means the per-dimension means to center on
     * @param dataset the list the indices refer to
     * @param subset the indices into {@code dataset} to include
     * @return the newly allocated vector holding the result
     */
    public static <V extends Vec> Vec covarianceDiag(Vec means, List<V> dataset, List<Integer> subset) {
        final int d = dataset.get(0).length();
        Vec diag = new DenseVector(d);
        covarianceDiag(means, diag, dataset, subset);
        return diag;
    }

    /**
     * Estimates a robust mean and covariance with a FastMCD-style search and
     * writes them into {@code mean} and {@code cov}.
     * <p>
     * Outline, with {@code N} points of dimension {@code D} and
     * {@code h = ceil((N + D + 1) / 2)}:
     * <ol>
     * <li>If {@code h >= N}, the ordinary mean and covariance of the whole
     * data set are used.</li>
     * <li>If {@code N <= 600}: 500 seeded random starts of {@code D + 1}
     * points each get three {@link #MCD_C_step} calls, and the ten with the
     * smallest returned determinant are kept.</li>
     * <li>Otherwise: up to 1500 shuffled indices are dealt into at most five
     * splits. Each split runs 100 seeded starts (three steps each, ten kept
     * per split). The kept candidates get three more steps with
     * {@code h_merged}, and the ten best are kept.</li>
     * <li>Each kept candidate is refined for up to 20 steps with the full
     * {@code h}, stopping early once the determinant changes by less than
     * {@code 1e-9}; the smallest determinant wins.</li>
     * <li>The winning covariance is rescaled by the squared median distance
     * over the chi-squared median. Points within the 97.5% chi-squared
     * quantile form the final set, from which the output mean and
     * covariance are computed.</li>
     * </ol>
     *
     * @param <V> the vector type stored in the list
     * @param mean output mean; cleared and overwritten
     * @param cov output covariance; cleared and overwritten
     * @param dataset the vectors to estimate from
     * @param parallel whether the candidate search and distance computations
     * may run in parallel
     * @throws ArithmeticException if {@code dataset} is empty, if no
     * candidate yields a comparable determinant, or if one of the mean /
     * covariance helpers rejects its input
     */
    public static <V extends Vec> void FastMCD(Vec mean, Matrix cov, List<V> dataset, boolean parallel) {
        if (dataset.isEmpty()) {
            throw new ArithmeticException("Can not compute the MCD estimate of zero data points");
        }
        final int N = dataset.size();
        final int D = dataset.get(0).length();
        final int h = (int) Math.ceil((N + D + 1) / 2.0);
        mean.zeroOut();
        cov.zeroOut();
        // h >= N means the "subset" is the whole data set (h > N happens when
        // D >= N), so there is nothing to search for.
        if (h >= N) {
            meanVector(mean, dataset);
            covarianceMatrix(mean, cov, dataset);
            return;
        }

        final List<Tuple3<Double, Vec, Matrix>> top10;
        if (N <= 600) {
            top10 = ParallelUtils.range(500, parallel).mapToObj(seed -> {
                Random rand = RandomUtil.getRandom(seed);
                // mean and cov were zeroed above, so the clones start empty.
                Vec subset_mean = mean.clone();
                Matrix subset_cov = cov.clone();
                IntList randOrder = ListUtils.range(0, N);
                Collections.shuffle(randOrder, rand);
                IntList h_prev = new IntList(randOrder.subList(0, D + 1));
                meanVector(subset_mean, dataset, h_prev);
                covarianceMatrix(subset_mean, subset_cov, dataset, h_prev);
                double det = 0.0;
                for (int step = 0; step < 3; ++step) {
                    det = MCD_C_step(subset_mean, subset_cov, dataset, h_prev, h, false);
                }
                return new Tuple3<Double, Vec, Matrix>(det, subset_mean, subset_cov);
            }).sorted((o1, o2) -> Double.compare(o1.getX(), o2.getX())).limit(10L).collect(Collectors.toList());
        } else {
            final int numSplits = N >= 1500 ? 5 : (int) Math.floor(N / 300.0);
            IntList randOrderAll = ListUtils.range(0, N);
            Collections.shuffle(randOrderAll, RandomUtil.getLocalRandom());
            final IntList[] splits = new IntList[numSplits];
            for (int s = 0; s < numSplits; ++s) {
                splits[s] = new IntList();
            }
            // Deal at most 1500 shuffled indices round-robin into the splits.
            final int sampled = Math.min(1500, randOrderAll.size());
            for (int j = 0; j < sampled; ++j) {
                splits[j % splits.length].add(randOrderAll.get(j));
            }
            // long arithmetic: size * h could overflow int for very large N.
            final int h_sub = (int) ((long) splits[0].size() * h / N);
            List<Tuple3<Double, Vec, Matrix>> splitSolutions = Arrays.asList(splits).stream().flatMap(split -> ParallelUtils.range(100, parallel).mapToObj(seed -> {
                Random rand = RandomUtil.getRandom(seed);
                Vec subset_mean = mean.clone();
                Matrix subset_cov = cov.clone();
                IntList randOrderSplit = new IntList((Collection<Integer>) split);
                Collections.shuffle(randOrderSplit, rand);
                IntList h_prev = new IntList(randOrderSplit.subList(0, D + 1));
                meanVector(subset_mean, dataset, h_prev);
                covarianceMatrix(subset_mean, subset_cov, dataset, h_prev);
                double det = 0.0;
                for (int step = 0; step < 3; ++step) {
                    det = MCD_C_step(subset_mean, subset_cov, dataset, h_prev, h_sub, false);
                }
                return new Tuple3<Double, Vec, Matrix>(det, subset_mean, subset_cov);
            }).sorted((o1, o2) -> Double.compare(o1.getX(), o2.getX())).limit(10L)).collect(Collectors.toList());

            IntSet splits_merged = new IntSet();
            for (IntList split : splits) {
                splits_merged.addAll(split);
            }
            final int h_merged = (int) ((long) splits_merged.size() * h / N);
            // Only go parallel when the caller asked for it.
            top10 = (parallel ? splitSolutions.parallelStream() : splitSolutions.stream()).map(tuple -> {
                Vec subset_mean = tuple.getY();
                Matrix subset_cov = tuple.getZ();
                IntList h_prev = new IntList();
                double det = 0.0;
                for (int step = 0; step < 3; ++step) {
                    det = MCD_C_step(subset_mean, subset_cov, dataset, h_prev, h_merged, false);
                }
                return new Tuple3<Double, Vec, Matrix>(det, subset_mean, subset_cov);
            }).sorted((o1, o2) -> Double.compare(o1.getX(), o2.getX())).limit(10L).collect(Collectors.toList());
        }

        // Refine every kept candidate with the full h and keep the best one.
        double bestDet = Double.POSITIVE_INFINITY;
        Vec bestMean = null;
        Matrix bestCov = null;
        for (Tuple3<Double, Vec, Matrix> candidate : top10) {
            double prevDet = candidate.getX();
            IntList h_prev = new IntList(h);
            Vec m = candidate.getY();
            Matrix c = candidate.getZ();
            for (int iter = 0; iter < 20; ++iter) {
                double newDet = MCD_C_step(m, c, dataset, h_prev, h, parallel);
                if (Math.abs(newDet - prevDet) < 1.0E-9) {
                    break; // converged; prevDet keeps the last accepted value
                }
                prevDet = newDet;
            }
            if (prevDet < bestDet) {
                bestCov = c;
                bestMean = m;
                bestDet = prevDet;
            }
        }
        if (bestMean == null || bestCov == null) {
            // Every determinant was NaN or +infinity, so nothing compared below the start value.
            throw new ArithmeticException("No candidate subset produced a finite covariance determinant");
        }

        // Final reweighting step.
        final Vec T_full = bestMean;
        final Matrix S_full = bestCov;
        MahalanobisDistance md = new MahalanobisDistance();
        LUPDecomposition lup = new LUPDecomposition(S_full.clone());
        md.setInverseCovariance(lup.solve(Matrix.eye(S_full.cols())));
        ChiSquared chi = new ChiSquared(S_full.cols());
        final double[] dist = new double[N];
        ParallelUtils.run(parallel, N, (start, end) -> {
            for (int k = start; k < end; ++k) {
                dist[k] = md.dist(T_full, dataset.get(k));
            }
        });
        IndexTable order = new IndexTable(dist);
        double reScale = Math.pow(dist[order.index(N / 2)], 2.0) / chi.invCdf(0.5);
        S_full.mutableMultiply(reScale);
        // Scaling the covariance by reScale divides squared distances by
        // reScale, so the (unsquared) distances shrink by sqrt(reScale).
        final double distScale = Math.sqrt(reScale);
        for (int k = 0; k < N; ++k) {
            dist[k] /= distScale;
        }
        double threshold = Math.sqrt(chi.invCdf(0.975));
        List<V> finalSet = new ArrayList<V>(N);
        for (int k = 0; k < N; ++k) {
            if (dist[k] <= threshold) {
                finalSet.add(dataset.get(k));
            }
        }
        mean.zeroOut();
        meanVector(mean, finalSet);
        cov.zeroOut();
        covarianceMatrix(mean, cov, finalSet);
    }

    /**
     * Performs one concentration step: ranks all points by Mahalanobis
     * distance under the current estimate, keeps the {@code h} closest, and
     * recomputes the mean and covariance from exactly those points.
     * <p>
     * Before inverting, {@code 1e-4} is added to every diagonal entry of
     * {@code subset_cov}. The returned determinant belongs to that
     * regularized input covariance, not to the covariance written back.
     *
     * @param <V> the vector type stored in the list
     * @param subset_mean current mean estimate; overwritten with the mean of
     * the selected points
     * @param subset_cov current covariance estimate; overwritten with the
     * covariance of the selected points
     * @param dataset all data points; distances are computed for every one
     * @param h_prev cleared and refilled with the indices of the {@code h}
     * closest points
     * @param h the number of points to keep
     * @param parallel currently not consulted by this method
     * @return the determinant of the regularized covariance the distances
     * were measured with
     */
    protected static <V extends Vec> double MCD_C_step(Vec subset_mean, Matrix subset_cov, List<V> dataset, IntList h_prev, int h, boolean parallel) {
        final int N = dataset.size();
        MahalanobisDistance md = new MahalanobisDistance();
        // Small ridge on the diagonal before the decomposition.
        for (int i = 0; i < subset_cov.rows(); ++i) {
            subset_cov.increment(i, i, 1.0E-4);
        }
        LUPDecomposition lup = new LUPDecomposition(subset_cov.clone());
        md.setInverseCovariance(lup.solve(Matrix.eye(subset_cov.cols())));
        double[] dists = new double[N];
        for (int i = 0; i < N; ++i) {
            dists[i] = md.dist(subset_mean, dataset.get(i));
        }
        IndexTable it = new IndexTable(dists);
        h_prev.clear();
        for (int i = 0; i < h; ++i) {
            h_prev.add(it.index(i));
        }
        // meanVector and covarianceMatrix accumulate into their targets, so
        // the previous estimate (and the ridge) must be cleared first.
        subset_mean.zeroOut();
        subset_cov.zeroOut();
        meanVector(subset_mean, dataset, h_prev);
        covarianceMatrix(subset_mean, subset_cov, dataset, h_prev);
        return lup.det();
    }
}

