/*
 * Decompiled with CFR 0.152.
 */
package jsat.distributions.multivariate;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import jsat.distributions.multivariate.MultivariateDistributionSkeleton;
import jsat.linear.CholeskyDecomposition;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Matrix;
import jsat.linear.MatrixStatistics;
import jsat.linear.SingularValueDecomposition;
import jsat.linear.Vec;

/**
 * A multivariate normal (Gaussian) distribution parameterised by a mean vector and either a
 * full covariance matrix or a diagonal covariance given as a vector of variances.
 * <p>
 * When a full covariance matrix looks numerically singular (its Cholesky log-determinant is
 * NaN or below {@code log(1e-10)}), the density is evaluated with the singular value
 * decomposition's pseudo-inverse and pseudo-determinant, i.e. as a degenerate normal of the
 * covariance's rank.
 */
public class NormalM
extends MultivariateDistributionSkeleton {
    private static final long serialVersionUID = -7043369396743253382L;
    /** log(2&pi;), the per-dimension normalisation term of the Gaussian log-density. */
    private static final double LOG_2PI = Math.log(Math.PI * 2);
    /** Additive constant of {@link #logPdf(Vec)}: everything except the quadratic form. */
    private double logPDFConst;
    /** (Pseudo-)inverse of a full covariance matrix; {@code null} in diagonal mode. */
    private Matrix invCovariance;
    /** Reciprocal variances (0 where a variance is 0); {@code null} in full-matrix mode. */
    private Vec invCov_diag;
    private Vec mean;
    /** Cholesky factor: the transpose of {@code CholeskyDecomposition.getLT()}; {@code null} in diagonal mode. */
    private Matrix L;
    /** Square roots of the diagonal variances; {@code null} in full-matrix mode. */
    private Vec L_diag;
    /** Log-determinant from the Cholesky decomposition (full mode) or the summed log-variances (diagonal mode). */
    private double log_det;

    /**
     * Creates a normal distribution with a full covariance matrix.
     *
     * @param mean the mean vector (a copy is stored)
     * @param covariance the square covariance matrix; its size must match the mean's length
     * @throws ArithmeticException if the covariance is not square or does not match the mean
     */
    public NormalM(Vec mean, Matrix covariance) {
        this.setMeanCovariance(mean, covariance);
    }

    /**
     * Creates a normal distribution with a diagonal covariance.
     *
     * @param mean the mean vector (a copy is stored)
     * @param diag_covariance the per-dimension variances; its length must match the mean's
     * @throws ArithmeticException if the lengths differ or a variance is negative
     */
    public NormalM(Vec mean, Vec diag_covariance) {
        this.mean = mean.clone();
        this.setCovariance(diag_covariance);
    }

    /**
     * Creates a distribution with no parameters. A mean and covariance must be supplied, for
     * example via {@link #setMeanCovariance(Vec, Matrix)} or {@link #setUsingData}, before use.
     */
    public NormalM() {
    }

    /**
     * Sets both the mean and a full covariance matrix. If the covariance cannot be applied, the
     * previous mean is restored before the exception propagates.
     *
     * @param mean the new mean vector (a copy is stored)
     * @param covariance the new square covariance matrix
     * @throws ArithmeticException if the covariance is not square or does not match the mean
     */
    public void setMeanCovariance(Vec mean, Matrix covariance) {
        if (!covariance.isSquare()) {
            throw new ArithmeticException("Covariance matrix must be square");
        }
        if (mean.length() != covariance.rows()) {
            throw new ArithmeticException("The mean vector and matrix must have the same dimension," + mean.length() + " does not match [" + covariance.rows() + ", " + covariance.rows() + "]");
        }
        Vec previousMean = this.mean;
        this.mean = mean.clone();
        try {
            this.setCovariance(covariance);
        }
        catch (RuntimeException ex) {
            // Do not leave a new mean paired with the old covariance state.
            this.mean = previousMean;
            throw ex;
        }
    }

    /**
     * Sets a full covariance matrix for the current mean. All derived state is computed into
     * locals first, so a failure leaves this object unchanged.
     *
     * @param covMatrix the square covariance matrix; it is not modified
     * @throws IllegalStateException if no mean has been set yet
     * @throws ArithmeticException if the matrix is not square or does not match the mean
     */
    public void setCovariance(Matrix covMatrix) {
        if (!covMatrix.isSquare()) {
            throw new ArithmeticException("Covariance matrix must be square");
        }
        if (this.mean == null) {
            throw new IllegalStateException("The mean must be set before the covariance");
        }
        if (covMatrix.rows() != this.mean.length()) {
            throw new ArithmeticException("Covariance matrix does not agree with the mean");
        }
        int k = this.mean.length();
        CholeskyDecomposition cd = new CholeskyDecomposition(covMatrix.clone());
        Matrix newL = cd.getLT();
        newL.mutableTranspose();
        double newLogDet = cd.getLogDet();
        double newLogPdfConst;
        Matrix newInvCovariance;
        if (Double.isNaN(newLogDet) || newLogDet < Math.log(1.0E-10)) {
            // Singular / ill-conditioned: use the degenerate normal of the covariance's rank,
            // log f = -0.5 * (rank*log(2pi) + log(pseudoDet)) - 0.5 * quadratic form.
            SingularValueDecomposition svd = new SingularValueDecomposition(covMatrix.clone());
            newLogPdfConst = -0.5 * (svd.getRank() * LOG_2PI + Math.log(svd.getPseudoDet()));
            newInvCovariance = svd.getPseudoInverse();
        } else {
            newLogPdfConst = -0.5 * (k * LOG_2PI + newLogDet);
            newInvCovariance = cd.solve(Matrix.eye(k));
        }
        // NOTE(review): in the singular branch L is still the Cholesky factor, which may not be
        // numerically valid; confirm sample() behaviour for singular covariances.
        this.L = newL;
        this.log_det = newLogDet;
        this.logPDFConst = newLogPdfConst;
        this.invCovariance = newInvCovariance;
        this.invCov_diag = null;
        this.L_diag = null;
    }

    /**
     * Sets a diagonal covariance for the current mean. Zero variances are tolerated: their
     * reciprocal is stored as 0. All derived state is computed before any field is changed.
     *
     * @param cov_diag the per-dimension variances; it is not modified
     * @throws IllegalStateException if no mean has been set yet
     * @throws ArithmeticException if the length does not match the mean or a variance is negative
     */
    public void setCovariance(Vec cov_diag) {
        if (this.mean == null) {
            throw new IllegalStateException("The mean must be set before the covariance");
        }
        if (cov_diag.length() != this.mean.length()) {
            throw new ArithmeticException("Covariance matrix does not agree with the mean");
        }
        int k = this.mean.length();
        double newLogDet = 0.0;
        // Sums over the entries the Vec iterator yields (presumably only non-zero ones; verify).
        for (IndexValue iv : cov_diag) {
            double variance = iv.getValue();
            if (variance < 0.0) {
                throw new ArithmeticException("Variances must be non-negative, found " + variance + " at index " + iv.getIndex());
            }
            newLogDet += Math.log(variance);
        }
        Vec stdDevs = cov_diag.clone();
        stdDevs.applyFunction(Math::sqrt);
        Vec reciprocalVariances = cov_diag.clone();
        reciprocalVariances.applyFunction(f -> f > 0.0 ? 1.0 / f : 0.0);
        this.log_det = newLogDet;
        this.logPDFConst = -0.5 * (k * LOG_2PI + newLogDet);
        this.L_diag = stdDevs;
        this.invCov_diag = reciprocalVariances;
        this.invCovariance = null;
        this.L = null;
    }

    /**
     * Returns the mean vector.
     * <p>
     * NOTE(review): this is the internal instance, not a copy; mutating it changes this
     * distribution.
     *
     * @return the mean, or {@code null} if none has been set
     */
    public Vec getMean() {
        return this.mean;
    }

    /**
     * Computes the natural log of the density at {@code x}.
     *
     * @param x the point to evaluate
     * @return {@code logPDFConst - 0.5 * (x - mean)^T * inverseCovariance * (x - mean)}
     * @throws ArithmeticException if no mean has been set
     */
    @Override
    public double logPdf(Vec x) {
        if (this.mean == null) {
            throw new ArithmeticException("No mean or variance set");
        }
        Vec centered = x.subtract(this.mean);
        double quadraticForm;
        if (this.invCov_diag != null) {
            quadraticForm = 0.0;
            for (IndexValue iv : centered) {
                double d = iv.getValue();
                quadraticForm += d * d * this.invCov_diag.get(iv.getIndex());
            }
        } else {
            quadraticForm = centered.dot(this.invCovariance.multiply(centered));
        }
        return this.logPDFConst - 0.5 * quadraticForm;
    }

    /**
     * Computes the density at {@code x} as {@code exp(logPdf(x))}.
     *
     * @param x the point to evaluate
     * @return the density, or 0 if the exponentiated value is infinite or NaN
     */
    @Override
    public double pdf(Vec x) {
        double pdf = Math.exp(this.logPdf(x));
        if (Double.isInfinite(pdf) || Double.isNaN(pdf)) {
            return 0.0;
        }
        return pdf;
    }

    /**
     * Fits the mean and full covariance matrix to the sample mean and covariance of
     * {@code dataSet}.
     *
     * @param dataSet the observations
     * @param parallel ignored by this implementation
     * @return {@code true} on success; {@code false} if an {@link ArithmeticException} occurred,
     *         in which case the previous mean is restored
     */
    @Override
    public <V extends Vec> boolean setUsingData(List<V> dataSet, boolean parallel) {
        Vec origMean = this.mean;
        try {
            Vec newMean = MatrixStatistics.meanVector(dataSet);
            Matrix covariance = MatrixStatistics.covarianceMatrix(newMean, dataSet);
            this.mean = newMean;
            this.setCovariance(covariance);
            return true;
        }
        catch (ArithmeticException ex) {
            this.mean = origMean;
            return false;
        }
    }

    /**
     * Creates a deep copy of this distribution, including every derived field. Without the
     * derived fields, a copy could not evaluate densities in diagonal mode or draw samples.
     *
     * @return an independent copy
     */
    @Override
    public NormalM clone() {
        NormalM copy = new NormalM();
        copy.mean = this.mean == null ? null : this.mean.clone();
        copy.invCovariance = this.invCovariance == null ? null : this.invCovariance.clone();
        copy.invCov_diag = this.invCov_diag == null ? null : this.invCov_diag.clone();
        copy.L = this.L == null ? null : this.L.clone();
        copy.L_diag = this.L_diag == null ? null : this.L_diag.clone();
        copy.log_det = this.log_det;
        copy.logPDFConst = this.logPDFConst;
        return copy;
    }

    /**
     * Draws {@code count} samples as {@code mean + A*z}, where {@code z} holds independent
     * standard normal values. {@code A} is the Cholesky factor in full mode, or the per-dimension
     * standard deviations in diagonal mode.
     *
     * @param count the number of samples to draw
     * @param rand the source of randomness
     * @return a new list of {@code count} new vectors
     * @throws ArithmeticException if no mean has been set
     */
    @Override
    public List<Vec> sample(int count, Random rand) {
        if (this.mean == null) {
            throw new ArithmeticException("No mean or variance set");
        }
        List<Vec> samples = new ArrayList<>(count);
        // z is reused across samples: both multiply() and pairwiseMultiply() return new vectors.
        DenseVector z = new DenseVector(this.mean.length());
        for (int i = 0; i < count; ++i) {
            for (int j = 0; j < z.length(); ++j) {
                z.set(j, rand.nextGaussian());
            }
            Vec sample = this.L != null ? this.L.multiply(z) : this.L_diag.pairwiseMultiply(z);
            sample.mutableAdd(this.mean);
            samples.add(sample);
        }
        return samples;
    }
}

