/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.WarmClassifier;
import jsat.classifiers.evaluation.ClassificationScore;
import jsat.datatransform.DataTransformProcess;
import jsat.exceptions.UntrainedModelException;
import jsat.math.OnLineStatistics;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

public class ClassificationModelEvaluation {
    private Classifier classifier;
    private ClassificationDataSet dataSet;
    private boolean parallel;
    private double[][] confusionMatrix;
    private double sumOfWeights;
    private long totalTrainingTime = 0L;
    private long totalClassificationTime = 0L;
    private DataTransformProcess dtp;
    private boolean keepPredictions;
    private CategoricalResults[] predictions;
    private int[] truths;
    private double[] pointWeights;
    private OnLineStatistics errorStats;
    private Map<ClassificationScore, OnLineStatistics> scoreMap;
    private boolean keepModels = false;
    private Classifier[] keptModels;
    private Classifier[] warmModels;

    public ClassificationModelEvaluation(Classifier classifier, ClassificationDataSet dataSet) {
        this(classifier, dataSet, false);
    }

    public ClassificationModelEvaluation(Classifier classifier, ClassificationDataSet dataSet, boolean parallel) {
        this.classifier = classifier;
        this.dataSet = dataSet;
        this.parallel = parallel;
        this.dtp = new DataTransformProcess();
        this.keepPredictions = false;
        this.errorStats = new OnLineStatistics();
        this.scoreMap = new LinkedHashMap<ClassificationScore, OnLineStatistics>();
    }

    public void setKeepModels(boolean keepModels) {
        this.keepModels = keepModels;
    }

    public boolean isKeepModels() {
        return this.keepModels;
    }

    public Classifier[] getKeptModels() {
        return this.keptModels;
    }

    public void setWarmModels(Classifier ... warmModels) {
        this.warmModels = warmModels;
    }

    public void setDataTransformProcess(DataTransformProcess dtp) {
        this.dtp = dtp.clone();
    }

    public void evaluateCrossValidation(int folds) {
        this.evaluateCrossValidation(folds, RandomUtil.getRandom());
    }

    public void evaluateCrossValidation(int folds, Random rand) {
        if (folds < 2) {
            throw new UntrainedModelException("Model could not be evaluated because " + folds + " is < 2, and not valid for cross validation");
        }
        List<ClassificationDataSet> lcds = this.dataSet.cvSet(folds, rand);
        this.evaluateCrossValidation(lcds);
    }

    public void evaluateCrossValidation(List<ClassificationDataSet> lcds) {
        ArrayList<ClassificationDataSet> trainCombinations = new ArrayList<ClassificationDataSet>(lcds.size());
        for (int i = 0; i < lcds.size(); ++i) {
            trainCombinations.add(ClassificationDataSet.comineAllBut(lcds, i));
        }
        this.evaluateCrossValidation(lcds, trainCombinations);
    }

    public void evaluateCrossValidation(List<ClassificationDataSet> lcds, List<ClassificationDataSet> trainCombinations) {
        int numOfClasses = this.dataSet.getClassSize();
        this.sumOfWeights = 0.0;
        this.confusionMatrix = new double[numOfClasses][numOfClasses];
        this.totalTrainingTime = 0L;
        this.totalClassificationTime = 0L;
        if (this.keepModels) {
            this.keptModels = new Classifier[lcds.size()];
        }
        this.setUpResults(this.dataSet.size());
        int end = this.dataSet.size();
        for (int i = lcds.size() - 1; i >= 0; --i) {
            ClassificationDataSet trainSet = trainCombinations.get(i);
            ClassificationDataSet testSet = lcds.get(i);
            this.evaluationWork(trainSet, testSet, i);
            int testSize = testSet.size();
            if (this.keepPredictions) {
                System.arraycopy(this.predictions, 0, this.predictions, end - testSize, testSize);
                System.arraycopy(this.truths, 0, this.truths, end - testSize, testSize);
                System.arraycopy(this.pointWeights, 0, this.pointWeights, end - testSize, testSize);
            }
            end -= testSize;
        }
    }

    public void evaluateTestSet(ClassificationDataSet testSet) {
        if (this.keepModels) {
            this.keptModels = new Classifier[1];
        }
        int numOfClasses = this.dataSet.getClassSize();
        this.sumOfWeights = 0.0;
        this.confusionMatrix = new double[numOfClasses][numOfClasses];
        this.setUpResults(testSet.size());
        this.totalClassificationTime = 0L;
        this.totalTrainingTime = 0L;
        this.evaluationWork(this.dataSet, testSet, 0);
    }

    private void evaluationWork(ClassificationDataSet trainSet, ClassificationDataSet testSet, int index) {
        ClassificationScore score;
        DataTransformProcess curProcess = this.dtp.clone();
        if (curProcess.getNumberOfTransforms() > 0) {
            trainSet = trainSet.shallowClone();
            curProcess.learnApplyTransforms(trainSet);
        }
        Classifier classifierToUse = this.classifier.clone();
        long startTrain = System.currentTimeMillis();
        if (this.warmModels != null && classifierToUse instanceof WarmClassifier) {
            WarmClassifier wc = (WarmClassifier)classifierToUse;
            wc.train(trainSet, this.warmModels[index], this.parallel);
        } else {
            classifierToUse.train(trainSet, this.parallel);
        }
        this.totalTrainingTime += System.currentTimeMillis() - startTrain;
        if (this.keptModels != null) {
            this.keptModels[index] = classifierToUse;
        }
        double[] evalErrorStats = new double[2];
        HashMap<ClassificationScore, ClassificationScore> scoresToUpdate = new HashMap<ClassificationScore, ClassificationScore>();
        for (Map.Entry<ClassificationScore, OnLineStatistics> entry : this.scoreMap.entrySet()) {
            score = entry.getKey().clone();
            score.prepare(this.dataSet.getPredicting());
            scoresToUpdate.put(score, score);
        }
        ParallelUtils.run(this.parallel, testSet.size(), (start, end) -> {
            double localCorrect = 0.0;
            double localSumOfWeights = 0.0;
            long localClassificationTime = 0L;
            HashSet<ClassificationScore> localScores = new HashSet<ClassificationScore>();
            for (Map.Entry entry : scoresToUpdate.entrySet()) {
                localScores.add(((ClassificationScore)entry.getKey()).clone());
            }
            for (int i = start; i < end; ++i) {
                DataPoint dp = testSet.getDataPoint(i);
                dp = curProcess.transform(dp);
                double w_i = testSet.getWeight(i);
                long stratClass = System.currentTimeMillis();
                CategoricalResults result = classifierToUse.classify(dp);
                localClassificationTime += System.currentTimeMillis() - stratClass;
                for (ClassificationScore score : localScores) {
                    score.addResult(result, testSet.getDataPointCategory(i), w_i);
                }
                if (this.predictions != null) {
                    this.predictions[i] = result;
                    this.truths[i] = testSet.getDataPointCategory(i);
                    this.pointWeights[i] = w_i;
                }
                int trueCat = testSet.getDataPointCategory(i);
                double[] dArray = this.confusionMatrix[trueCat];
                synchronized (dArray) {
                    double[] dArray2 = this.confusionMatrix[trueCat];
                    int n = result.mostLikely();
                    dArray2[n] = dArray2[n] + w_i;
                }
                if (trueCat == result.mostLikely()) {
                    localCorrect += w_i;
                }
                localSumOfWeights += w_i;
            }
            double[][] dArray = this.confusionMatrix;
            synchronized (this.confusionMatrix) {
                this.totalClassificationTime += localClassificationTime;
                this.sumOfWeights += localSumOfWeights;
                evalErrorStats[0] = evalErrorStats[0] + (localSumOfWeights - localCorrect);
                evalErrorStats[1] = evalErrorStats[1] + localSumOfWeights;
                for (ClassificationScore score : localScores) {
                    ((ClassificationScore)scoresToUpdate.get(score)).addResults(score);
                }
                // ** MonitorExit[var15_14] (shouldn't be in output)
                return;
            }
        });
        this.errorStats.add(evalErrorStats[0] / evalErrorStats[1]);
        for (Map.Entry<ClassificationScore, OnLineStatistics> entry : this.scoreMap.entrySet()) {
            score = entry.getKey().clone();
            score.prepare(this.dataSet.getPredicting());
            score.addResults((ClassificationScore)scoresToUpdate.get(score));
            entry.getValue().add(score.getScore());
        }
    }

    public void addScorer(ClassificationScore scorer) {
        this.scoreMap.put(scorer, new OnLineStatistics());
    }

    public OnLineStatistics getScoreStats(ClassificationScore score) {
        return this.scoreMap.get(score);
    }

    public void keepPredictions(boolean keepPredictions) {
        this.keepPredictions = keepPredictions;
    }

    public boolean doseStoreResults() {
        return this.keepPredictions;
    }

    public CategoricalResults[] getPredictions() {
        return this.predictions;
    }

    public int[] getTruths() {
        return this.truths;
    }

    public double[] getPointWeights() {
        return this.pointWeights;
    }

    public double[][] getConfusionMatrix() {
        return this.confusionMatrix;
    }

    public void prettyPrintConfusionMatrix() {
        int i;
        CategoricalData predicting = this.dataSet.getPredicting();
        int classCount = predicting.getNumOfCategories();
        int nameLength = 10;
        for (int i2 = 0; i2 < classCount; ++i2) {
            nameLength = Math.max(nameLength, predicting.getOptionName(i2).length() + 2);
        }
        String pfx = "%-" + nameLength;
        System.out.printf(pfx + "s ", "Matrix");
        for (i = 0; i < classCount - 1; ++i) {
            System.out.printf(pfx + "s ", predicting.getOptionName(i).toUpperCase());
        }
        System.out.printf(pfx + "s\n", predicting.getOptionName(classCount - 1).toUpperCase());
        for (i = 0; i < this.confusionMatrix.length; ++i) {
            System.out.printf(pfx + "s ", predicting.getOptionName(i).toUpperCase());
            for (int j = 0; j < classCount - 1; ++j) {
                System.out.printf(pfx + "f ", this.confusionMatrix[i][j]);
            }
            System.out.printf(pfx + "f\n", this.confusionMatrix[i][classCount - 1]);
        }
    }

    public void prettyPrintClassificationScores() {
        int nameLength = 10;
        for (Map.Entry<ClassificationScore, OnLineStatistics> entry : this.scoreMap.entrySet()) {
            nameLength = Math.max(nameLength, entry.getKey().getName().length() + 2);
        }
        String pfx = "%-" + nameLength;
        for (Map.Entry<ClassificationScore, OnLineStatistics> entry : this.scoreMap.entrySet()) {
            OnLineStatistics stats = entry.getValue();
            if (stats.getMax() == stats.getMin()) {
                System.out.printf(pfx + "s %-5f\n", entry.getKey().getName(), stats.getMean());
                continue;
            }
            System.out.printf(pfx + "s %-5f (%-5f)\n", entry.getKey().getName(), stats.getMean(), stats.getStandardDeviation());
        }
    }

    public double getCorrectWeights() {
        double val = 0.0;
        for (int i = 0; i < this.confusionMatrix.length; ++i) {
            val += this.confusionMatrix[i][i];
        }
        return val;
    }

    public double getSumOfWeights() {
        return this.sumOfWeights;
    }

    public double getErrorRate() {
        return 1.0 - this.getCorrectWeights() / this.sumOfWeights;
    }

    public OnLineStatistics getErrorRateStats() {
        return this.errorStats;
    }

    public long getTotalTrainingTime() {
        return this.totalTrainingTime;
    }

    public long getTotalClassificationTime() {
        return this.totalClassificationTime;
    }

    public Classifier getClassifier() {
        return this.classifier;
    }

    private void setUpResults(int resultSize) {
        if (this.keepPredictions) {
            this.predictions = new CategoricalResults[resultSize];
            this.truths = new int[this.predictions.length];
            this.pointWeights = new double[this.predictions.length];
        } else {
            this.predictions = null;
            this.truths = null;
            this.pointWeights = null;
        }
    }
}

