/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.biclustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import jsat.linear.DenseMatrix;
import jsat.linear.Matrix;
import jsat.utils.Pair;
import jsat.utils.concurrent.ParallelUtils;

public class ConsensusScore {
    public static double score(boolean parallel, List<List<Integer>> rows_truth, List<List<Integer>> cols_truth, List<List<Integer>> rows_found, List<List<Integer>> cols_found) {
        int k_true = rows_truth.size();
        int k_found = rows_found.size();
        double[][] cost_matrix = new double[k_true][k_found];
        ParallelUtils.run(parallel, k_true, i -> {
            Set<Pair<Integer, Integer>> true_ci = ConsensusScore.coCluster_to_set(rows_truth, i, cols_truth);
            for (int j = 0; j < k_found; ++j) {
                Set<Pair<Integer, Integer>> true_cj = ConsensusScore.coCluster_to_set(rows_found, j, cols_found);
                int A_size = true_ci.size();
                int B_size = true_cj.size();
                true_cj.removeIf(pair -> !true_ci.contains(pair));
                int union = true_cj.size();
                cost_matrix[i][j] = 1.0 - (double)union / (double)(A_size + B_size - union);
            }
        });
        Map<Integer, Integer> assignments = ConsensusScore.assignment(new DenseMatrix(cost_matrix));
        double score_sum = 0.0;
        for (Map.Entry<Integer, Integer> pair : assignments.entrySet()) {
            score_sum += 1.0 - cost_matrix[pair.getKey()][pair.getValue()];
        }
        return score_sum / (double)Math.max(k_true, k_found);
    }

    private static Set<Pair<Integer, Integer>> coCluster_to_set(List<List<Integer>> rows_truth, int q, List<List<Integer>> cols_truth) {
        HashSet<Pair<Integer, Integer>> true_c_i = new HashSet<Pair<Integer, Integer>>();
        List<Integer> rows = rows_truth.get(q);
        List<Integer> cols = cols_truth.get(q);
        for (int i = 0; i < rows.size(); ++i) {
            for (int j = 0; j < cols.size(); ++j) {
                true_c_i.add(new Pair<Integer, Integer>(rows.get(i), cols.get(j)));
            }
        }
        return true_c_i;
    }

    private static Map<Integer, Integer> assignment(Matrix A2) {
        HashMap<Integer, Integer> assignments = new HashMap<Integer, Integer>();
        boolean[] taken = new boolean[A2.cols()];
        for (int i = 0; i < A2.rows(); ++i) {
            int min_indx = -1;
            double best_score = Double.POSITIVE_INFINITY;
            for (int j = 0; j < A2.cols(); ++j) {
                double score = A2.get(i, j);
                if (!(score < best_score) || taken[j]) continue;
                best_score = score;
                min_indx = j;
            }
            assignments.put(i, min_indx);
            taken[min_indx] = true;
            if (assignments.size() == Math.min(A2.rows(), A2.cols())) break;
        }
        return assignments;
    }
}

