/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.linear.IndexValue;
import jsat.linear.Matrix;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.utils.SystemInfo;

/**
 * A sparse matrix stored row by row: every row is held in its own
 * {@link SparseVector}, and {@link #getRowView(int)} hands out that live row
 * object directly.
 * <p>
 * Overloads that take an {@link ExecutorService} submit their work to that pool
 * and block until every submitted task has finished. If any task fails, the
 * first failure is rethrown to the caller once all tasks are done.
 */
public class SparseMatrix extends Matrix {
    private static final long serialVersionUID = -4087445771022578544L;

    /** The rows of this matrix; every entry is expected to have the same length. */
    private SparseVector[] rows;

    /**
     * Creates an all-zero {@code rows} x {@code cols} matrix.
     *
     * @param rows        the number of rows
     * @param cols        the number of columns (length of every row vector)
     * @param rowCapacity passed to each row's {@code SparseVector(int, int)}
     *                    constructor (presumably its initial non-zero capacity)
     */
    public SparseMatrix(int rows, int cols, int rowCapacity) {
        this.rows = new SparseVector[rows];
        for (int i = 0; i < rows; ++i) {
            this.rows[i] = new SparseVector(cols, rowCapacity);
        }
    }

    /**
     * Creates a matrix backed directly by the given row vectors. The array is
     * not copied, so later changes to it or to its vectors are visible here.
     *
     * @param rows the rows of the new matrix; all must have the same length.
     *             An empty array is accepted, but {@link #cols()} reads the
     *             first row and will then fail.
     * @throws IllegalArgumentException if any row's length differs from the first row's
     */
    public SparseMatrix(SparseVector[] rows) {
        this.rows = rows;
        for (int i = 0; i < rows.length; ++i) {
            if (rows[i].length() != rows[0].length()) {
                throw new IllegalArgumentException("Row " + i + " has " + rows[i].length() + " columns instead of " + rows[0].length());
            }
        }
    }

    /**
     * Creates an all-zero {@code rows} x {@code cols} matrix using the default
     * row vector constructor.
     *
     * @param rows the number of rows
     * @param cols the number of columns
     */
    public SparseMatrix(int rows, int cols) {
        this.rows = new SparseVector[rows];
        for (int i = 0; i < rows; ++i) {
            this.rows[i] = new SparseVector(cols);
        }
    }

    /**
     * Copy constructor: clones every row of {@code toCopy}, so the new matrix
     * shares no row objects with the original.
     *
     * @param toCopy the matrix to copy
     */
    protected SparseMatrix(SparseMatrix toCopy) {
        this.rows = new SparseVector[toCopy.rows.length];
        for (int i = 0; i < this.rows.length; ++i) {
            this.rows[i] = toCopy.rows[i].clone();
        }
    }

    /**
     * Performs {@code this += c * B}, row by row.
     *
     * @throws ArithmeticException if {@code B} does not have the same dimensions
     */
    @Override
    public void mutableAdd(double c, Matrix B) {
        if (!Matrix.sameDimensions(this, B)) {
            throw new ArithmeticException("Matrices must be the same dimension to be added");
        }
        for (int i = 0; i < this.rows.length; ++i) {
            this.rows[i].mutableAdd(c, B.getRowView(i));
        }
    }

    /**
     * Performs {@code this += c * B}, submitting one task per row to {@code threadPool}.
     *
     * @throws ArithmeticException if {@code B} does not have the same dimensions
     */
    @Override
    public void mutableAdd(final double c, final Matrix B, ExecutorService threadPool) {
        if (!Matrix.sameDimensions(this, B)) {
            throw new ArithmeticException("Matrices must be the same dimension to be added");
        }
        List<Future<?>> futures = new ArrayList<Future<?>>(this.rows.length);
        for (int i = 0; i < this.rows.length; ++i) {
            final int row = i;
            futures.add(threadPool.submit(new Runnable() {
                @Override
                public void run() {
                    SparseMatrix.this.rows[row].mutableAdd(c, B.getRowView(row));
                }
            }));
        }
        awaitAll(futures);
    }

    /** Adds the scalar {@code c} to the entries of every row via {@code SparseVector.mutableAdd(double)}. */
    @Override
    public void mutableAdd(double c) {
        for (SparseVector row : this.rows) {
            row.mutableAdd(c);
        }
    }

    /** Parallel version of {@link #mutableAdd(double)}: one task per row. */
    @Override
    public void mutableAdd(final double c, ExecutorService threadPool) {
        List<Future<?>> futures = new ArrayList<Future<?>>(this.rows.length);
        for (final SparseVector row : this.rows) {
            futures.add(threadPool.submit(new Runnable() {
                @Override
                public void run() {
                    row.mutableAdd(c);
                }
            }));
        }
        awaitAll(futures);
    }

    /**
     * Accumulates {@code c += z * A b} by incrementing {@code c[i]} with
     * {@code z} times the dot product of row {@code i} and {@code b}.
     *
     * @throws ArithmeticException if {@code b} or {@code c} has the wrong length
     */
    @Override
    public void multiply(Vec b, double z, Vec c) {
        if (this.cols() != b.length()) {
            throw new ArithmeticException("Matrix dimensions do not agree, [" + this.rows() + "," + this.cols() + "] x [" + b.length() + ",1]");
        }
        if (this.rows() != c.length()) {
            throw new ArithmeticException("Target vector dimension does not agree with matrix dimensions. Matrix has " + this.rows() + " rows but target has " + c.length());
        }
        for (int i = 0; i < this.rows.length; ++i) {
            c.increment(i, this.rows[i].dot(b) * z);
        }
    }

    /**
     * Accumulates {@code C2 += A * B} by adding scaled rows of {@code B} into
     * the row views of {@code C2}. This relies on {@code C2.getRowView}
     * returning write-through views.
     *
     * @throws ArithmeticException if the dimensions of {@code B} or {@code C2} do not agree
     */
    @Override
    public void multiply(Matrix B, Matrix C2) {
        if (!SparseMatrix.canMultiply(this, B)) {
            throw new ArithmeticException("Matrix dimensions do not agree");
        }
        if (this.rows() != C2.rows() || B.cols() != C2.cols()) {
            throw new ArithmeticException("Target Matrix is not the correct size");
        }
        for (int i = 0; i < C2.rows(); ++i) {
            Vec cRow = C2.getRowView(i);
            // only the non-zeros of A's row contribute: C[i,:] += A[i,k] * B[k,:]
            for (IndexValue iv : this.rows[i]) {
                cRow.mutableAdd(iv.getValue(), B.getRowView(iv.getIndex()));
            }
        }
    }

    /**
     * Parallel version of {@link #multiply(Matrix, Matrix)}. Each task fills
     * one row of {@code C2}.
     *
     * @throws ArithmeticException if the dimensions of {@code B} or {@code C2} do not agree
     */
    @Override
    public void multiply(final Matrix B, Matrix C2, ExecutorService threadPool) {
        if (!SparseMatrix.canMultiply(this, B)) {
            throw new ArithmeticException("Matrix dimensions do not agree");
        }
        if (this.rows() != C2.rows() || B.cols() != C2.cols()) {
            throw new ArithmeticException("Target Matrix is not the correct size");
        }
        List<Future<?>> futures = new ArrayList<Future<?>>(C2.rows());
        for (int i = 0; i < C2.rows(); ++i) {
            final SparseVector aRow = this.rows[i];
            final Vec cRow = C2.getRowView(i);
            futures.add(threadPool.submit(new Runnable() {
                @Override
                public void run() {
                    for (IndexValue iv : aRow) {
                        cRow.mutableAdd(iv.getValue(), B.getRowView(iv.getIndex()));
                    }
                }
            }));
        }
        awaitAll(futures);
    }

    /** Multiplies every row by the scalar {@code c}. */
    @Override
    public void mutableMultiply(double c) {
        for (SparseVector row : this.rows) {
            row.mutableMultiply(c);
        }
    }

    /** Parallel version of {@link #mutableMultiply(double)}: one task per row. */
    @Override
    public void mutableMultiply(final double c, ExecutorService threadPool) {
        List<Future<?>> futures = new ArrayList<Future<?>>(this.rows.length);
        for (final SparseVector row : this.rows) {
            futures.add(threadPool.submit(new Runnable() {
                @Override
                public void run() {
                    row.mutableMultiply(c);
                }
            }));
        }
        awaitAll(futures);
    }

    /** Not implemented for sparse matrices; always throws. */
    @Override
    public Matrix[] lup() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /** Not implemented for sparse matrices; always throws. */
    @Override
    public Matrix[] lup(ExecutorService threadPool) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /** Not implemented for sparse matrices; always throws. */
    @Override
    public Matrix[] qr() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /** Not implemented for sparse matrices; always throws. */
    @Override
    public Matrix[] qr(ExecutorService threadPool) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Transposes this matrix in place. Only square matrices can be transposed
     * in place, because the shape would otherwise change.
     * <p>
     * The transposed rows are first gathered from the non-zeros only, instead
     * of probing every (i, j) pair with {@code get}/{@code set}. Each existing
     * row vector is then cleared and refilled, so row views previously
     * obtained from {@link #getRowView(int)} stay attached to this matrix.
     *
     * @throws ArithmeticException if this matrix is not square
     */
    @Override
    public void mutableTranspose() {
        int n = this.rows.length;
        if (n != this.cols()) {
            throw new ArithmeticException("Can only transpose a square matrix in place");
        }
        SparseVector[] transposed = new SparseVector[n];
        for (int i = 0; i < n; ++i) {
            transposed[i] = new SparseVector(n);
        }
        for (int r = 0; r < n; ++r) {
            for (IndexValue iv : this.rows[r]) {
                transposed[iv.getIndex()].set(r, iv.getValue());
            }
        }
        for (int i = 0; i < n; ++i) {
            SparseVector row = this.rows[i];
            row.zeroOut();
            for (IndexValue iv : transposed[i]) {
                row.set(iv.getIndex(), iv.getValue());
            }
        }
    }

    /**
     * Overwrites {@code C2} with the transpose of this matrix.
     *
     * @param C2 the destination; must be {@code cols() x rows()}. It may be
     *           this matrix itself, in which case it is transposed in place.
     * @throws ArithmeticException if {@code C2} has the wrong dimensions
     */
    @Override
    public void transpose(Matrix C2) {
        if (this.rows() != C2.cols() || this.cols() != C2.rows()) {
            throw new ArithmeticException("Target matrix does not have the correct dimensions");
        }
        if (C2 == this) {
            // Zeroing C2 first would wipe out the source. The check above
            // guarantees a square shape here, so transpose in place instead.
            this.mutableTranspose();
            return;
        }
        C2.zeroOut();
        for (int row = 0; row < this.rows.length; ++row) {
            for (IndexValue iv : this.rows[row]) {
                C2.set(iv.getIndex(), row, iv.getValue());
            }
        }
    }

    /**
     * Accumulates {@code C2 += A^T * B}. For every row {@code k}, each non-zero
     * {@code A[k,i]} adds {@code A[k,i] * B[k,:]} into row {@code i} of {@code C2}.
     *
     * @throws ArithmeticException if the dimensions of {@code B} or {@code C2} do not agree
     */
    @Override
    public void transposeMultiply(Matrix B, Matrix C2) {
        if (this.rows() != B.rows()) {
            throw new ArithmeticException("Matrix dimensions do not agree");
        }
        if (this.cols() != C2.rows() || B.cols() != C2.cols()) {
            throw new ArithmeticException("Destination matrix does not have matching dimensions");
        }
        for (int k = 0; k < this.rows.length; ++k) {
            Vec bRow = B.getRowView(k);
            for (IndexValue iv : this.rows[k]) {
                C2.getRowView(iv.getIndex()).mutableAdd(iv.getValue(), bRow);
            }
        }
    }

    /**
     * Same as {@link #transposeMultiply(Matrix, Matrix)}. The thread pool is
     * ignored and the work runs on the calling thread. Different rows of A can
     * write to the same row of {@code C2}, so the computation is not split
     * into tasks.
     */
    @Override
    public void transposeMultiply(Matrix B, Matrix C2, ExecutorService threadPool) {
        this.transposeMultiply(B, C2);
    }

    /**
     * Accumulates {@code x += c * A^T b}, adding {@code c * b[k]} times row
     * {@code k} for each non-zero {@code b[k]}.
     *
     * @throws ArithmeticException if {@code b} or {@code x} has the wrong length
     */
    @Override
    public void transposeMultiply(double c, Vec b, Vec x) {
        if (this.rows() != b.length()) {
            throw new ArithmeticException("Matrix dimensions do not agree, [" + this.cols() + "," + this.rows() + "] x [" + b.length() + ",1]");
        }
        if (this.cols() != x.length()) {
            throw new ArithmeticException("Matrix dimensions do not agree with target vector");
        }
        for (IndexValue bIv : b) {
            x.mutableAdd(c * bIv.getValue(), this.rows[bIv.getIndex()]);
        }
    }

    /** Returns the live row vector {@code r} (not a copy); changes to it modify this matrix. */
    @Override
    public Vec getRowView(int r) {
        return this.rows[r];
    }

    @Override
    public double get(int i, int j) {
        return this.rows[i].get(j);
    }

    @Override
    public void set(int i, int j, double value) {
        this.rows[i].set(j, value);
    }

    @Override
    public void increment(int i, int j, double value) {
        this.rows[i].increment(j, value);
    }

    @Override
    public int rows() {
        return this.rows.length;
    }

    /** Returns the length of the first row; this fails if the matrix has no rows. */
    @Override
    public int cols() {
        return this.rows[0].length();
    }

    @Override
    public boolean isSparce() {
        return true;
    }

    /** Swaps two rows by exchanging their row-vector references. */
    @Override
    public void swapRows(int r1, int r2) {
        SparseVector tmp = this.rows[r2];
        this.rows[r2] = this.rows[r1];
        this.rows[r1] = tmp;
    }

    /** Zeroes every row, keeping the same row vector objects. */
    @Override
    public void zeroOut() {
        for (SparseVector row : this.rows) {
            row.zeroOut();
        }
    }

    /** Returns a deep copy in which every row vector is cloned. */
    @Override
    public SparseMatrix clone() {
        return new SparseMatrix(this);
    }

    /**
     * Returns the total number of non-zeros across all rows. The count is
     * accumulated in a {@code long} so large matrices cannot overflow it.
     */
    @Override
    public long nnz() {
        long nnz = 0;
        for (SparseVector v : this.rows) {
            nnz += v.nnz();
        }
        return nnz;
    }

    /**
     * Resizes this matrix in place. Values outside the new bounds are dropped,
     * and any new rows are all zero.
     *
     * @throws ArithmeticException if {@code newRows} or {@code newCols} is not positive
     */
    @Override
    public void changeSize(int newRows, int newCols) {
        if (newRows <= 0) {
            throw new ArithmeticException("Matrix must have a positive number of rows");
        }
        if (newCols <= 0) {
            throw new ArithmeticException("Matrix must have a positive number of columns");
        }
        int oldRows = this.rows.length;
        boolean colsChanged = newCols != this.cols();
        // drop removed rows first so they are not pointlessly column-truncated
        this.rows = Arrays.copyOf(this.rows, newRows);
        if (colsChanged) {
            int keptRows = Math.min(oldRows, newRows);
            for (int i = 0; i < keptRows; ++i) {
                SparseVector row = this.rows[i];
                // clear entries past the new width before shrinking; this assumes
                // set(idx, 0.0) removes the stored non-zero
                while (row.getLastNonZeroIndex() >= newCols) {
                    row.set(row.getLastNonZeroIndex(), 0.0);
                }
                row.setLength(newCols);
            }
        }
        for (int i = oldRows; i < newRows; ++i) {
            this.rows[i] = new SparseVector(newCols);
        }
    }

    /**
     * Accumulates {@code C2 += A * B^T}: entry (i, j) is incremented by the dot
     * product of row i of this matrix and row j of {@code B}.
     *
     * @throws ArithmeticException if the dimensions of {@code B} or {@code C2} do not agree
     */
    @Override
    public void multiplyTranspose(Matrix B, Matrix C2) {
        if (this.cols() != B.cols()) {
            throw new ArithmeticException("Matrix dimensions do not agree");
        }
        if (this.rows() != C2.rows() || B.rows() != C2.cols()) {
            throw new ArithmeticException("Target Matrix is not the correct size");
        }
        for (int i = 0; i < this.rows.length; ++i) {
            SparseVector aRow = this.rows[i];
            for (int j = 0; j < B.rows(); ++j) {
                accumulateRowDot(aRow, B.getRowView(j), C2, i, j);
            }
        }
    }

    /**
     * Parallel version of {@link #multiplyTranspose(Matrix, Matrix)}.
     * {@code SystemInfo.LogicalCores} tasks are submitted. Task {@code t}
     * handles rows {@code t, t + LogicalCores, ...}, so no two tasks write to
     * the same row of {@code C2}.
     *
     * @throws ArithmeticException if the dimensions of {@code B} or {@code C2} do not agree
     */
    @Override
    public void multiplyTranspose(final Matrix B, final Matrix C2, ExecutorService threadPool) {
        if (this.cols() != B.cols()) {
            throw new ArithmeticException("Matrix dimensions do not agree");
        }
        if (this.rows() != C2.rows() || B.rows() != C2.cols()) {
            throw new ArithmeticException("Target Matrix is not the correct size");
        }
        final int stride = SystemInfo.LogicalCores;
        List<Future<?>> futures = new ArrayList<Future<?>>(stride);
        for (int id = 0; id < stride; ++id) {
            final int start = id;
            futures.add(threadPool.submit(new Runnable() {
                @Override
                public void run() {
                    SparseVector[] aRows = SparseMatrix.this.rows;
                    for (int i = start; i < aRows.length; i += stride) {
                        SparseVector aRow = aRows[i];
                        for (int j = 0; j < B.rows(); ++j) {
                            accumulateRowDot(aRow, B.getRowView(j), C2, i, j);
                        }
                    }
                }
            }));
        }
        awaitAll(futures);
    }

    /**
     * Increments {@code C.get(i, j)} by the dot product of {@code a} and {@code b}.
     * When {@code b} is dense, the non-zeros of {@code a} are used to index into
     * {@code b}. When {@code b} is sparse, the two non-zero iterators are merged
     * by index, and if either vector has no non-zeros {@code C} is not touched.
     */
    private static void accumulateRowDot(SparseVector a, Vec b, Matrix C, int i, int j) {
        double sum = 0.0;
        if (!b.isSparse()) {
            for (IndexValue iv : a) {
                sum += iv.getValue() * b.get(iv.getIndex());
            }
            C.increment(i, j, sum);
            return;
        }
        Iterator<IndexValue> aIter = a.getNonZeroIterator();
        Iterator<IndexValue> bIter = b.getNonZeroIterator();
        if (!bIter.hasNext() || !aIter.hasNext()) {
            return;
        }
        IndexValue aVal = aIter.next();
        IndexValue bVal = bIter.next();
        // indices are read before advancing, in case an iterator reuses its IndexValue
        while (aVal != null && bVal != null) {
            int aIdx = aVal.getIndex();
            int bIdx = bVal.getIndex();
            if (aIdx == bIdx) {
                sum += aVal.getValue() * bVal.getValue();
                aVal = nextOrNull(aIter);
                bVal = nextOrNull(bIter);
            } else if (aIdx < bIdx) {
                aVal = nextOrNull(aIter);
            } else {
                bVal = nextOrNull(bIter);
            }
        }
        C.increment(i, j, sum);
    }

    /** Returns the iterator's next element, or {@code null} once it is exhausted. */
    private static IndexValue nextOrNull(Iterator<IndexValue> iter) {
        return iter.hasNext() ? iter.next() : null;
    }

    /**
     * Blocks until every submitted task has finished, then rethrows the first
     * task failure, if any, so that a crashed worker is never silently ignored.
     * If the calling thread is interrupted while waiting, the interrupt is
     * logged, the thread's interrupt flag is restored, and the method returns
     * without waiting for the remaining tasks.
     *
     * @param futures the futures returned by {@code ExecutorService.submit}
     */
    private static void awaitAll(List<Future<?>> futures) {
        Throwable failure = null;
        for (Future<?> future : futures) {
            try {
                future.get();
            } catch (InterruptedException ex) {
                Logger.getLogger(SparseMatrix.class.getName()).log(Level.SEVERE, null, ex);
                Thread.currentThread().interrupt();
                return;
            } catch (ExecutionException ex) {
                // keep waiting so no task is still mutating data when we throw
                if (failure == null) {
                    failure = ex.getCause();
                }
            }
        }
        if (failure instanceof RuntimeException) {
            throw (RuntimeException) failure;
        }
        if (failure instanceof Error) {
            throw (Error) failure;
        }
        if (failure != null) {
            throw new RuntimeException(failure);
        }
    }
}

