/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear;

import java.io.Serializable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.linear.DenseMatrix;
import jsat.linear.LUPDecomposition;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.utils.SystemInfo;

/**
 * Cholesky decomposition {@code A = L L^T} of a square, symmetric positive
 * definite matrix {@code A}.
 * <p>
 * The factorization is computed in place: the matrix handed to a constructor
 * is overwritten and retained by this object. When construction finishes,
 * its diagonal and lower triangle hold {@code L} and its strictly upper
 * triangle holds the mirrored entries, i.e. {@code L^T}, so the single
 * stored matrix carries both factors.
 */
public class CholeskyDecomposition
implements Serializable {
    private static final long serialVersionUID = 8925094456733750112L;

    /**
     * The factor: {@code L} in the lower triangle, {@code L^T} in the upper
     * triangle. This is the same object that was passed to the constructor.
     */
    private Matrix L;

    /**
     * Computes the Cholesky decomposition of {@code A} in place on the
     * calling thread.
     * <p>
     * Only the diagonal and the strictly lower triangle of {@code A} are
     * read; symmetry is not verified. {@code A} is overwritten with the
     * result and kept by this object.
     *
     * @param A the square, symmetric positive definite matrix to factor;
     *          modified by this call
     * @throws ArithmeticException if {@code A} is not square, or if a pivot
     *         is negative or NaN (the matrix is not positive definite)
     */
    public CholeskyDecomposition(Matrix A) {
        if (!A.isSquare()) {
            throw new ArithmeticException("Input matrix must be symmetric positive definite");
        }
        this.L = A;
        int rows = A.rows();
        for (int j = 0; j < rows; j++) {
            double pivot = computeLJJ(A, j);
            L.set(j, j, pivot);
            updateRows(j, j + 1, rows, 1, A, pivot);
        }
        mirrorLowerToUpper(rows);
    }

    /**
     * Computes the Cholesky decomposition of {@code A} in place, splitting
     * the below-diagonal work of each column across
     * {@code SystemInfo.LogicalCores} interleaved shares. The calling thread
     * processes one share and {@code LogicalCores - 1} tasks are submitted
     * to {@code threadpool}. While those tasks run, the calling thread also
     * computes the next column's diagonal entry.
     * <p>
     * If the calling thread is interrupted, it still waits for the submitted
     * tasks of the current column, so the result is never left half-written.
     * The thread's interrupt status is restored before the constructor
     * returns or throws.
     *
     * @param A          the square, symmetric positive definite matrix to
     *                   factor; modified by this call
     * @param threadpool executor that runs the worker shares
     * @throws ArithmeticException if {@code A} is not square, or if a pivot
     *         is negative or NaN (the matrix is not positive definite)
     */
    public CholeskyDecomposition(final Matrix A, ExecutorService threadpool) {
        if (!A.isSquare()) {
            throw new ArithmeticException("Input matrix must be symmetric positive definite");
        }
        this.L = A;
        final int rows = A.rows();
        final int cores = SystemInfo.LogicalCores;
        boolean interrupted = false;
        try {
            // Guard: a 0x0 matrix has no (0, 0) entry to read.
            double nextPivot = rows > 0 ? computeLJJ(A, 0) : 0.0;
            for (int j = 0; j < rows; j++) {
                final int col = j;
                final double pivot = nextPivot;
                L.set(j, j, pivot);

                final CountDownLatch latch = new CountDownLatch(cores - 1);
                for (int id = 1; id < cores; id++) {
                    final int offset = id;
                    threadpool.submit(new Runnable() {
                        @Override
                        public void run() {
                            try {
                                updateRows(col, col + 1 + offset, rows, cores, A, pivot);
                            } finally {
                                // Always count down, otherwise a failing task
                                // would leave the calling thread waiting forever.
                                latch.countDown();
                            }
                        }
                    });
                }

                try {
                    // The calling thread's share starts at row j + 1, so it
                    // alone writes L(j+1, j); computing the next pivot (which
                    // reads row j+1, columns 0..j) therefore needs no extra sync.
                    updateRows(j, j + 1, rows, cores, A, pivot);
                    if (j + 1 < rows) {
                        nextPivot = computeLJJ(A, j + 1);
                    }
                } finally {
                    // Column j must be fully written before column j+1 starts,
                    // and before any exception escapes while tasks still
                    // mutate A.
                    interrupted |= awaitUninterruptibly(latch);
                }
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
        mirrorLowerToUpper(rows);
    }

    /**
     * Returns a new matrix holding the diagonal and upper triangle of the
     * stored factor, i.e. {@code L^T}. Entries below the diagonal are not
     * written (presumably zero in a freshly constructed {@link DenseMatrix}).
     *
     * @return a newly allocated copy of {@code L^T}
     */
    public Matrix getLT() {
        int n = L.rows();
        Matrix LT = new DenseMatrix(n, L.cols());
        for (int i = 0; i < n; i++) {
            for (int j = i; j < n; j++) {
                LT.set(i, j, L.get(i, j));
            }
        }
        return LT;
    }

    /**
     * Solves {@code A x = b} by forward substitution against the stored
     * factor, followed by back substitution against it.
     *
     * @param b the right-hand side
     * @return the solution {@code x}
     */
    public Vec solve(Vec b) {
        Vec y = LUPDecomposition.forwardSub(L, b);
        return LUPDecomposition.backSub(L, y);
    }

    /**
     * Solves {@code A X = B} for every column of {@code B} by forward
     * substitution followed by back substitution against the stored factor.
     *
     * @param B the right-hand sides, one per column
     * @return the solution {@code X}
     */
    public Matrix solve(Matrix B) {
        Matrix y = LUPDecomposition.forwardSub(L, B);
        return LUPDecomposition.backSub(L, y);
    }

    /**
     * Parallel variant of {@link #solve(Matrix)}. The substitutions are
     * delegated to the {@code threadpool}-accepting overloads of
     * {@link LUPDecomposition}.
     *
     * @param B          the right-hand sides, one per column
     * @param threadpool executor passed through to the substitution routines
     * @return the solution {@code X}
     */
    public Matrix solve(Matrix B, ExecutorService threadpool) {
        Matrix y = LUPDecomposition.forwardSub(L, B, threadpool);
        return LUPDecomposition.backSub(L, y, threadpool);
    }

    /**
     * Returns the determinant of the factored matrix as
     * {@code exp(getLogDet())}. For large matrices this can overflow to
     * infinity or underflow to zero; prefer {@link #getLogDet()} then.
     *
     * @return the determinant of {@code A}
     */
    public double getDet() {
        return Math.exp(getLogDet());
    }

    /**
     * Returns {@code log det(A) = 2 * sum_i log L(i, i)}.
     *
     * @return the natural log of the determinant of {@code A}
     */
    public double getLogDet() {
        double logDet = 0.0;
        for (int i = 0; i < L.rows(); i++) {
            logDet += 2.0 * Math.log(L.get(i, i));
        }
        return logDet;
    }

    /**
     * Computes the diagonal entry {@code L(j, j) = sqrt(A(j, j) - sum_{k<j} L(j, k)^2)}.
     * Requires columns {@code 0..j-1} of row {@code j} to be final.
     * <p>
     * NOTE: a pivot of exactly zero is not rejected here (sqrt(0) is 0, not
     * NaN). The subsequent division in {@link #updateRows} then yields
     * infinite or NaN entries.
     *
     * @throws ArithmeticException if the value under the square root is
     *         negative or NaN
     */
    private double computeLJJ(Matrix A, int j) {
        double sum = A.get(j, j);
        for (int k = 0; k < j; k++) {
            double ljk = L.get(j, k);
            sum -= ljk * ljk;
        }
        double result = Math.sqrt(sum);
        if (Double.isNaN(result)) {
            throw new ArithmeticException("input matrix is not positive definite");
        }
        return result;
    }

    /**
     * Computes {@code L(i, j) = (A(i, j) - sum_{k<j} L(i, k) L(j, k)) / L_jj}
     * for rows {@code i = start, start + skip, ...} below {@code end}. The
     * {@code skip} stride lets several threads interleave the rows of one
     * column without touching the same entries.
     */
    private void updateRows(int j, int start, int end, int skip, Matrix A, double L_jj) {
        for (int i = start; i < end; i += skip) {
            double lij = A.get(i, j);
            for (int k = 0; k < j; k++) {
                lij -= L.get(i, k) * L.get(j, k);
            }
            L.set(i, j, lij / L_jj);
        }
    }

    /**
     * Copies each strictly-lower entry {@code L(i, j)} into {@code (j, i)},
     * so the upper triangle of the stored matrix holds {@code L^T}.
     */
    private void mirrorLowerToUpper(int rows) {
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < i; j++) {
                L.set(j, i, L.get(i, j));
            }
        }
    }

    /**
     * Waits until {@code latch} reaches zero, resuming the wait if the
     * calling thread is interrupted.
     *
     * @return {@code true} if an interrupt was received while waiting; the
     *         caller is responsible for restoring the interrupt status
     */
    private static boolean awaitUninterruptibly(CountDownLatch latch) {
        boolean interrupted = false;
        while (true) {
            try {
                latch.await();
                return interrupted;
            } catch (InterruptedException ignored) {
                // Keep waiting: returning early would let column j+1 start
                // while workers are still writing column j.
                interrupted = true;
            }
        }
    }
}

