/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.biclustering;

import java.util.List;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.clustering.Clusterer;
import jsat.clustering.KClusterer;
import jsat.clustering.biclustering.Bicluster;
import jsat.clustering.kmeans.GMeans;
import jsat.clustering.kmeans.HamerlyKMeans;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Matrix;
import jsat.linear.SubMatrix;
import jsat.linear.TruncatedSVD;
import jsat.linear.Vec;
import jsat.utils.IntList;

/**
 * Biclustering by spectral co-clustering. Rows and columns of the data matrix
 * are embedded together in a low-dimensional space built from singular vectors
 * of a normalized copy of the matrix, and are then clustered jointly.
 * <p>
 * Steps performed by {@link #bicluster(DataSet, int, boolean, List, List)}:
 * <ol>
 * <li>Normalize the data matrix with the configured {@link InputNormalization}.
 * This also yields per-row and per-column scaling factors.</li>
 * <li>Take a truncated SVD with {@code l + 1} components, where
 * {@code l = ceil(log2(clusters))}, and drop the leading singular pair.</li>
 * <li>For {@link InputNormalization#SCALE}, rescale the remaining left/right
 * singular vectors by the row/column factors.</li>
 * <li>Stack one point per row, followed by one point per column, into a single
 * dataset and cluster it.</li>
 * <li>Keep only the clusters that contain at least one row and at least one
 * column.</li>
 * </ol>
 * With the {@code SCALE} normalization this corresponds to the bipartite
 * spectral co-clustering scheme of Dhillon (2001).
 * <p>
 * Instances are mutable through their setters.
 */
public class SpectralCoClustering
implements Bicluster {
    /** Clusterer applied to the joint row/column embedding. */
    private Clusterer baseClusterAlgo;
    /** How the data matrix is normalized before the SVD is taken. */
    private InputNormalization inputNormalization;

    /**
     * Creates a co-clusterer with the following defaults:
     * <ul>
     * <li>{@link InputNormalization#SCALE} normalization;</li>
     * <li>a {@link GMeans} base clusterer backed by {@link HamerlyKMeans}.</li>
     * </ul>
     */
    public SpectralCoClustering() {
        this(InputNormalization.SCALE);
    }

    /**
     * Creates a co-clusterer with the given normalization and a {@link GMeans}
     * base clusterer backed by {@link HamerlyKMeans}.
     *
     * @param normalization how to normalize the data matrix before the SVD
     */
    public SpectralCoClustering(InputNormalization normalization) {
        this(normalization, new GMeans(new HamerlyKMeans()));
    }

    /**
     * Creates a co-clusterer with an explicit normalization and base clusterer.
     *
     * @param normalization how to normalize the data matrix before the SVD
     * @param baseCluster clusterer applied to the joint row/column embedding.
     *        When a fixed number of clusters is requested and this is not a
     *        {@link KClusterer}, a {@link HamerlyKMeans} is used instead.
     */
    public SpectralCoClustering(InputNormalization normalization, Clusterer baseCluster) {
        this.setBaseClusterAlgo(baseCluster);
        this.setInputNormalization(normalization);
    }

    /**
     * Sets how the data matrix is normalized before the SVD.
     *
     * @param inputNormalization the normalization to use
     */
    public void setInputNormalization(InputNormalization inputNormalization) {
        this.inputNormalization = inputNormalization;
    }

    /**
     * Returns how the data matrix is normalized before the SVD.
     *
     * @return the normalization in use
     */
    public InputNormalization getInputNormalization() {
        return this.inputNormalization;
    }

    /**
     * Sets the clusterer that is applied to the joint row/column embedding.
     *
     * @param baseClusterAlgo the clusterer to use
     */
    public void setBaseClusterAlgo(Clusterer baseClusterAlgo) {
        this.baseClusterAlgo = baseClusterAlgo;
    }

    /**
     * Returns the clusterer that is applied to the joint row/column embedding.
     *
     * @return the base clusterer
     */
    public Clusterer getBaseClusterAlgo() {
        return this.baseClusterAlgo;
    }

    /**
     * Finds up to {@code clusters} biclusters.
     *
     * @param dataSet the data whose matrix is co-clustered
     * @param clusters the number of clusters the base clusterer is asked for;
     *        must be at least 1
     * @param parallel passed through to the base clusterer
     * @param row_assignments cleared, then filled with one list of row indices
     *        per bicluster
     * @param col_assignments cleared, then filled with one list of column
     *        indices per bicluster; it has the same size as
     *        {@code row_assignments}
     * @throws IllegalArgumentException if {@code clusters < 1}
     */
    @Override
    public void bicluster(DataSet dataSet, int clusters, boolean parallel, List<List<Integer>> row_assignments, List<List<Integer>> col_assignments) {
        if (clusters < 1) {
            throw new IllegalArgumentException("Number of clusters must be positive, not " + clusters);
        }
        Matrix A = dataSet.getDataMatrix();
        if (clusters == 1) {
            // ceil(log2(1)) = 0 would produce a zero-dimensional embedding.
            // With a single cluster, every row and column trivially shares label 0.
            this.createAssignments(row_assignments, col_assignments, 1, A, new int[A.rows() + A.cols()]);
            return;
        }
        DenseVector R = new DenseVector(A.rows());
        DenseVector C = new DenseVector(A.cols());
        Matrix A_n = this.inputNormalization.normalize(A, R, C);
        int l = ceilLog2(clusters);
        SimpleDataSet Z = this.create_Z_dataset(A_n, l, R, C, this.inputNormalization);
        // A fixed cluster count needs a KClusterer; fall back to k-means otherwise.
        KClusterer to_use = this.baseClusterAlgo instanceof KClusterer ? (KClusterer) this.baseClusterAlgo : new HamerlyKMeans();
        int[] joint_designations = to_use.cluster((DataSet) Z, clusters, parallel, null);
        this.createAssignments(row_assignments, col_assignments, clusters, A, joint_designations);
    }

    /**
     * Finds biclusters, letting the base clusterer decide how many clusters to
     * form. The embedding uses {@code l = ceil(log2(min(rows, cols)))}
     * dimensions.
     *
     * @param dataSet the data whose matrix is co-clustered
     * @param parallel passed through to the base clusterer
     * @param row_assignments cleared, then filled with one list of row indices
     *        per bicluster
     * @param col_assignments cleared, then filled with one list of column
     *        indices per bicluster; it has the same size as
     *        {@code row_assignments}
     */
    public void bicluster(DataSet dataSet, boolean parallel, List<List<Integer>> row_assignments, List<List<Integer>> col_assignments) {
        Matrix A = dataSet.getDataMatrix();
        DenseVector R = new DenseVector(A.rows());
        DenseVector C = new DenseVector(A.cols());
        Matrix A_n = this.inputNormalization.normalize(A, R, C);
        int k_max = Math.min(A.rows(), A.cols());
        int l = ceilLog2(k_max);
        SimpleDataSet Z = this.create_Z_dataset(A_n, l, R, C, this.inputNormalization);
        int[] joint_designations = this.baseClusterAlgo.cluster(Z, parallel, null);
        // Number of clusters = largest label + 1. Negative (unassigned) labels do not count.
        int clusters = 0;
        for (int label : joint_designations) {
            clusters = Math.max(clusters, label + 1);
        }
        this.createAssignments(row_assignments, col_assignments, clusters, A, joint_designations);
    }

    /**
     * Computes {@code ceil(log2(k))} exactly with integer arithmetic. The
     * floating-point form {@code Math.log(k) / Math.log(2)} can round slightly
     * above an exact power of two.
     *
     * @param k the value to take the logarithm of
     * @return the smallest {@code l >= 0} with {@code 2^l >= k}; 0 when
     *         {@code k <= 1}
     */
    private static int ceilLog2(int k) {
        return k <= 1 ? 0 : Integer.SIZE - Integer.numberOfLeadingZeros(k - 1);
    }

    /**
     * Builds the joint embedding, with one point per row and one per column:
     * <ul>
     * <li>first, one point per row of {@code A_n}, taken from its left singular
     * vectors;</li>
     * <li>then, one point per column, taken from its right singular
     * vectors.</li>
     * </ul>
     *
     * @param A_n the normalized data matrix
     * @param l the embedding dimension; singular vectors 2 through {@code l+1}
     *        are used
     * @param R row scaling factors produced by the normalization
     * @param C column scaling factors produced by the normalization
     * @param inputNormalization the normalization that was used; only
     *        {@code SCALE} rescales the singular vectors
     * @return a dataset of {@code A_n.rows() + A_n.cols()} points of dimension
     *         {@code l}
     */
    private SimpleDataSet create_Z_dataset(Matrix A_n, int l, DenseVector R, DenseVector C, InputNormalization inputNormalization) {
        TruncatedSVD svd = new TruncatedSVD(A_n, l + 1);
        Matrix U = svd.getU();
        // After the transpose, each row of V corresponds to one column of A_n.
        // createAssignments relies on this ordering.
        Matrix V = svd.getV().transpose();
        // Discard the leading singular pair. Presumably this is because it is
        // the trivial (degree-only) pair for the SCALE normalization.
        int to_skip = 1;
        U = new SubMatrix(U, 0, to_skip, U.rows(), l + to_skip);
        V = new SubMatrix(V, 0, to_skip, V.rows(), l + to_skip);
        if (inputNormalization == InputNormalization.SCALE) {
            // Scale row i of U by R[i] and row j of V by C[j] (in place on the views).
            Matrix.diagMult(R, U);
            Matrix.diagMult(C, V);
        }
        SimpleDataSet Z = new SimpleDataSet(l, new CategoricalData[0]);
        for (int i = 0; i < U.rows(); i++) {
            Z.add(new DataPoint(U.getRow(i)));
        }
        for (int j = 0; j < V.rows(); j++) {
            Z.add(new DataPoint(V.getRow(j)));
        }
        return Z;
    }

    /**
     * Converts joint cluster labels into per-bicluster row and column index
     * lists.
     * <p>
     * Layout of {@code joint_designations}:
     * <ul>
     * <li>entries {@code 0 .. A.rows()-1} label the rows of {@code A};</li>
     * <li>the following {@code A.cols()} entries label its columns;</li>
     * <li>negative labels are ignored.</li>
     * </ul>
     * Clusters left without any row, or without any column, are removed. As a
     * result, both output lists end with the same size, which is at most
     * {@code clusters}.
     *
     * @param row_assignments cleared and filled with row indices per bicluster
     * @param col_assignments cleared and filled with column indices per
     *        bicluster
     * @param clusters number of distinct non-negative labels; every
     *        non-negative label must be below this value
     * @param A the original data matrix; only its dimensions are used
     * @param joint_designations cluster label for every row, then every column
     */
    private void createAssignments(List<List<Integer>> row_assignments, List<List<Integer>> col_assignments, int clusters, Matrix A, int[] joint_designations) {
        row_assignments.clear();
        col_assignments.clear();
        for (int c = 0; c < clusters; c++) {
            row_assignments.add(new IntList());
            col_assignments.add(new IntList());
        }
        int rows = A.rows();
        for (int i = 0; i < rows; i++) {
            int label = joint_designations[i];
            if (label >= 0) {
                row_assignments.get(label).add(i);
            }
        }
        for (int j = 0; j < A.cols(); j++) {
            int label = joint_designations[rows + j];
            if (label >= 0) {
                col_assignments.get(label).add(j);
            }
        }
        // Walk backwards so that removals do not shift indices still to be visited.
        for (int c = row_assignments.size() - 1; c >= 0; c--) {
            if (row_assignments.get(c).isEmpty() || col_assignments.get(c).isEmpty()) {
                row_assignments.remove(c);
                col_assignments.remove(c);
            }
        }
    }

    /**
     * Returns an independent copy of this co-clusterer.
     * <ul>
     * <li>The normalization is an immutable enum constant, so it is shared.</li>
     * <li>The base clusterer is cloned, so the copy does not share its state
     * with this instance.</li>
     * </ul>
     *
     * @return a new {@code SpectralCoClustering} with the same configuration
     */
    @Override
    public SpectralCoClustering clone() {
        Clusterer baseCopy = this.baseClusterAlgo == null ? null : this.baseClusterAlgo.clone();
        return new SpectralCoClustering(this.inputNormalization, baseCopy);
    }

    /**
     * Computes a single row/column normalization step.
     * <p>
     * The result is {@code diag(R) * A * diag(C)}, where:
     * <ul>
     * <li>{@code R[i] = 1/sqrt(sum of row i)};</li>
     * <li>{@code C[j] = 1/sqrt(sum of column j)};</li>
     * <li>an all-zero sum gives a factor of 0.</li>
     * </ul>
     * NOTE(review): a negative row or column sum yields a NaN factor, because
     * the square root of a negative number is NaN. Non-negative input
     * appears to be assumed.
     *
     * @param A the matrix to normalize; it is not modified, because a clone is
     *        scaled
     * @param R output: overwritten with the row scaling factors; its length
     *        must be {@code A.rows()}
     * @param C output: overwritten with the column scaling factors; its length
     *        must be {@code A.cols()}
     * @return the normalized copy of {@code A}
     */
    protected static Matrix row_col_normalize(Matrix A, Vec R, Vec C) {
        R.zeroOut();
        C.zeroOut();
        for (int i = 0; i < A.rows(); i++) {
            // Iterate over the row view, accumulating row and column sums.
            for (IndexValue iv : A.getRowView(i)) {
                double value = iv.getValue();
                R.increment(i, value);
                C.increment(iv.getIndex(), value);
            }
        }
        R.applyFunction(v -> v == 0.0 ? 0.0 : 1.0 / Math.sqrt(v));
        C.applyFunction(v -> v == 0.0 ? 0.0 : 1.0 / Math.sqrt(v));
        Matrix A_n = A.clone();
        Matrix.diagMult(R, A_n);
        Matrix.diagMult(A_n, C);
        return A_n;
    }

    /**
     * Strategies for normalizing the input matrix before the SVD is taken.
     */
    public static enum InputNormalization {
        /**
         * Applies one row/column normalization step (see
         * {@code row_col_normalize}). The resulting row and column embeddings
         * are later rescaled by the returned factors.
         */
        SCALE {
            @Override
            public Matrix normalize(Matrix A, DenseVector R, DenseVector C) {
                return SpectralCoClustering.row_col_normalize(A, R, C);
            }
        },
        /**
         * Repeats the row/column normalization step until it converges or an
         * iteration cap is reached.
         * <ul>
         * <li>Convergence: the mean L2 distance between corresponding rows of
         * successive iterates is at most 1e-4.</li>
         * <li>Iteration cap: 1000 iterations.</li>
         * <li>{@code R} and {@code C} receive the element-wise product of all
         * per-step scaling factors.</li>
         * </ul>
         */
        BISTOCHASTIZATION {
            @Override
            public Matrix normalize(Matrix A, DenseVector R, DenseVector C) {
                // Running products of the per-step scaling factors, starting at all ones.
                DenseVector R_total = R.clone();
                R_total.zeroOut();
                R_total.mutableAdd(1.0);
                DenseVector C_total = C.clone();
                C_total.zeroOut();
                C_total.mutableAdd(1.0);
                double diff = Double.POSITIVE_INFINITY;
                int iter = 0;
                while (iter++ < 1000 && diff > 1.0E-4) {
                    Matrix A_prev = A;
                    A = SpectralCoClustering.row_col_normalize(A, R, C);
                    // Mean L2 distance between corresponding rows of successive iterates.
                    diff = 0.0;
                    for (int row = 0; row < A.rows(); row++) {
                        diff += A.getRowView(row).pNormDist(2.0, A_prev.getRowView(row));
                    }
                    diff /= (double) A.rows();
                    R_total.mutablePairwiseMultiply(R);
                    C_total.mutablePairwiseMultiply(C);
                }
                R_total.copyTo(R);
                C_total.copyTo(C);
                return A;
            }
        };

        /**
         * Normalizes a data matrix without modifying it.
         *
         * @param A the matrix to normalize
         * @param R output: receives the row scaling factors; its length must be
         *        {@code A.rows()}
         * @param C output: receives the column scaling factors; its length must
         *        be {@code A.cols()}
         * @return the normalized matrix
         */
        public abstract Matrix normalize(Matrix A, DenseVector R, DenseVector C);
    }
}

