/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.bayesian;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.utils.IntSet;

/**
 * A conditional probability table (CPT) built over a chosen subset of the
 * categorical variables of a {@link ClassificationDataSet}. The subset may
 * include the data set's predicting variable, which is addressed by the index
 * {@code dataSet.getNumCategoricalVars()} (one past the last categorical
 * feature).
 * <p>
 * The table is stored as a flat array of weighted counts. Every cell starts at
 * {@code 1.0}, so combinations never seen during training still receive a
 * non-zero count.
 */
public class ConditionalProbabilityTable implements Classifier {
    private static final long serialVersionUID = -287709075031023626L;

    /** Category information of the training set's predicting variable. */
    private CategoricalData predicting;
    /**
     * Flat storage of the n-dimensional count table. Cells are addressed via
     * {@link #cordToIndex(int...)}; the last dimension varies fastest.
     */
    private double[] countArray;
    /** Category information for every data-set category index that is part of the table. */
    private Map<Integer, CategoricalData> valid;
    /** For each table dimension, the data-set category index it represents. */
    private int[] realIndexToCatIndex;
    /**
     * For each data-set category index (including the predicting index), the
     * table dimension that represents it, or {@code -1} if it is not part of
     * the table.
     */
    private int[] catIndexToRealIndex;
    /** Number of categories along each table dimension. */
    private int[] dimSize;
    /** Index that denotes the predicting variable: the number of categorical variables of the training set. */
    private int predictingIndex;

    /**
     * Estimates the distribution of the predicting variable for a data point by
     * reading the count of every target category (with all other table
     * dimensions fixed by the data point) and normalizing the result.
     *
     * @param data the data point to classify
     * @return the normalized counts for each category of the predicting variable
     * @throws UntrainedModelException if the table has not been trained, or was
     *         trained without the predicting variable
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        // catIndexToRealIndex stays null until trainC has run; treat "never trained"
        // the same as "trained without the predicting variable".
        if (this.catIndexToRealIndex == null || this.catIndexToRealIndex[this.predictingIndex] < 0) {
            throw new UntrainedModelException("CPT has not been trained for a classification problem");
        }
        final int targetDim = this.catIndexToRealIndex[this.predictingIndex];
        CategoricalResults cr = new CategoricalResults(this.predicting.getNumOfCategories());
        int[] cord = new int[this.dimSize.length];
        // The paired value -1 is only a placeholder for the target slot; it is
        // overwritten with each candidate category below.
        this.dataPointToCord(new DataPointPair<Integer>(data, -1), this.predictingIndex, cord);
        for (int i = 0; i < cr.size(); ++i) {
            cord[targetDim] = i;
            cr.setProb(i, this.countArray[this.cordToIndex(cord)]);
        }
        cr.normalize();
        return cr;
    }

    /**
     * Returns the number of dimensions (variables) of the table.
     *
     * @return the number of dimensions of the trained table
     */
    public int getDimensionSize() {
        return this.dimSize.length;
    }

    /**
     * Converts a data point into a table coordinate. For every table dimension
     * the data point's categorical value is stored, except for the dimension of
     * the predicting variable, which takes the value paired with the data point.
     *
     * @param dataPoint the data point together with the value of the predicting variable
     * @param targetClass the data-set category index whose value should be returned
     * @param cord the array that receives the coordinate; its length must equal
     *        {@link #getDimensionSize()}
     * @return the value of {@code targetClass} for the data point (the paired
     *         value when {@code targetClass} equals the data point's number of
     *         categorical values), or {@code -1} if {@code targetClass} is not a
     *         dimension of this table
     * @throws ArithmeticException if {@code cord} does not match the table's dimension
     */
    public int dataPointToCord(DataPointPair<Integer> dataPoint, int targetClass, int[] cord) {
        if (cord.length != this.getDimensionSize()) {
            throw new ArithmeticException("Storage space and CPT dimension miss match");
        }
        DataPoint dp = dataPoint.getDataPoint();
        int skipVal = -1;
        for (int i = 0; i < this.dimSize.length; ++i) {
            final int catIndex = this.realIndexToCatIndex[i];
            if (catIndex == targetClass) {
                if (targetClass == dp.numCategoricalValues()) {
                    skipVal = dataPoint.getPair().intValue();
                } else {
                    skipVal = dp.getCategoricalValue(catIndex);
                }
            }
            if (catIndex == this.predictingIndex) {
                cord[i] = dataPoint.getPair().intValue();
            } else {
                cord[i] = dp.getCategoricalValue(catIndex);
            }
        }
        return skipVal;
    }

    /**
     * Trains the table on all categorical variables and the predicting
     * variable. The {@code parallel} flag is ignored; training is always done
     * by {@link #train(ClassificationDataSet)}.
     *
     * @param dataSet the data set to train on
     * @param parallel ignored
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Trains the table on every categorical variable of the data set plus the
     * predicting variable.
     *
     * @param dataSet the data set to train on
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        IntSet all = new IntSet();
        final int variableCount = dataSet.getNumCategoricalVars() + 1;
        for (int i = 0; i < variableCount; ++i) {
            all.add(Integer.valueOf(i));
        }
        this.trainC(dataSet, all);
    }

    /**
     * Trains the table using only the given variables. Indices
     * {@code 0 .. dataSet.getNumCategoricalVars() - 1} select categorical
     * features; the index {@code dataSet.getNumCategoricalVars()} selects the
     * predicting variable, which (if present) always becomes the last table
     * dimension.
     * <p>
     * The state of this object is only replaced once the requested layout has
     * been validated, so a rejected request leaves a previously trained table
     * untouched.
     *
     * @param dataSet the data set to train on
     * @param categoriesToUse the category indices that make up the table
     * @throws FailedToFitException if more variables are requested than the
     *         data set provides, if an index is out of range, or if the table
     *         would need more cells than fit into a single array
     */
    public void trainC(ClassificationDataSet dataSet, Set<Integer> categoriesToUse) {
        if (categoriesToUse.size() > dataSet.getNumFeatures() + 1) {
            throw new FailedToFitException("CPT can not train on a number of features greater then the dataset's feature count. Specified " + categoriesToUse.size() + " but data set has only " + dataSet.getNumFeatures());
        }
        final CategoricalData[] categories = dataSet.getCategories();
        final CategoricalData target = dataSet.getPredicting();
        final int targetIndex = dataSet.getNumCategoricalVars();

        // Reject bad indices up front instead of failing later with an
        // ArrayIndexOutOfBoundsException on a half-initialised model.
        for (int cat : categoriesToUse) {
            if (cat < 0 || cat > targetIndex) {
                throw new FailedToFitException("Category index " + cat + " is out of range, valid indices are 0 to " + targetIndex + " (the last one denoting the predicting variable)");
            }
        }

        final int dims = categoriesToUse.size();
        final Map<Integer, CategoricalData> used = new HashMap<Integer, CategoricalData>();
        final int[] dimToCat = new int[dims];
        final int[] catToDim = new int[targetIndex + 1];
        Arrays.fill(catToDim, -1);
        final int[] sizes = new int[dims];
        // Tracked as long so that an overflowing product is detected rather
        // than silently wrapping around.
        long cells = 1L;
        int k = 0;
        for (int cat : categoriesToUse) {
            if (cat == targetIndex) {
                continue; // the predicting variable is appended last, below
            }
            CategoricalData info = categories[cat];
            sizes[k] = info.getNumOfCategories();
            cells = growCellCount(cells, sizes[k]);
            used.put(cat, info);
            dimToCat[k] = cat;
            catToDim[cat] = k;
            ++k;
        }
        if (categoriesToUse.contains(targetIndex)) {
            sizes[k] = target.getNumOfCategories();
            cells = growCellCount(cells, sizes[k]);
            dimToCat[k] = targetIndex;
            catToDim[targetIndex] = k;
            used.put(targetIndex, target);
        }

        // Layout is valid: commit it.
        this.predicting = target;
        this.predictingIndex = targetIndex;
        this.valid = used;
        this.realIndexToCatIndex = dimToCat;
        this.catIndexToRealIndex = catToDim;
        this.dimSize = sizes;
        this.countArray = new double[(int) cells];
        // Start every cell at 1 so unseen combinations keep a non-zero count.
        Arrays.fill(this.countArray, 1.0);

        final int[] cordinate = new int[dims];
        for (int i = 0; i < dataSet.size(); ++i) {
            DataPoint dp = dataSet.getDataPoint(i);
            for (int j = 0; j < dims; ++j) {
                final int cat = dimToCat[j];
                cordinate[j] = cat != targetIndex ? dp.getCategoricalValue(cat) : dataSet.getDataPointCategory(i);
            }
            this.countArray[this.cordToIndex(cordinate)] += dataSet.getWeight(i);
        }
    }

    /**
     * Queries the table for the probability of the value that the data point
     * has for {@code targetClass}, given the data point's values for all other
     * table dimensions.
     *
     * @param targetClass the data-set category index to query
     * @param dataPoint the data point together with the value of the predicting variable
     * @return the count of the observed target value divided by the sum of the
     *         counts over all values of the target
     * @throws UntrainedModelException if the table has not been trained
     * @throws IllegalArgumentException if {@code targetClass} is not part of this table
     */
    public double query(int targetClass, DataPointPair<Integer> dataPoint) {
        this.tableDimensionOf(targetClass); // fail early with a clear message
        int[] cord = new int[this.dimSize.length];
        int skipVal = this.dataPointToCord(dataPoint, targetClass, cord);
        return this.query(targetClass, skipVal, cord);
    }

    /**
     * Queries the table for the probability that {@code targetClass} takes the
     * value {@code targetValue}, with all other dimensions fixed by {@code cord}.
     * <p>
     * Note that the entry of {@code cord} belonging to the target dimension is
     * overwritten by this call.
     *
     * @param targetClass the data-set category index to query
     * @param targetValue the value of the target whose probability is wanted
     * @param cord the table coordinate; the target dimension's entry is ignored
     *        and overwritten
     * @return the count for {@code targetValue} divided by the sum of the counts
     *         over all values of the target; {@code 0} if {@code targetValue}
     *         is not one of the target's values
     * @throws UntrainedModelException if the table has not been trained
     * @throws IllegalArgumentException if {@code targetClass} is not part of
     *         this table or {@code cord} has the wrong length
     */
    public double query(int targetClass, int targetValue, int[] cord) {
        final int realTargetIndex = this.tableDimensionOf(targetClass);
        final CategoricalData queryData = this.valid.get(targetClass);
        double sumVal = 0.0;
        double targetVal = 0.0;
        for (int i = 0; i < queryData.getNumOfCategories(); ++i) {
            cord[realTargetIndex] = i;
            final double count = this.countArray[this.cordToIndex(cord)];
            sumVal += count;
            if (i == targetValue) {
                targetVal = count;
            }
        }
        return targetVal / sumVal;
    }

    /**
     * Looks up the table dimension that represents a data-set category index.
     *
     * @param targetClass the data-set category index
     * @return the table dimension representing {@code targetClass}
     * @throws UntrainedModelException if the table has not been trained
     * @throws IllegalArgumentException if {@code targetClass} is not part of this table
     */
    private int tableDimensionOf(int targetClass) {
        if (this.catIndexToRealIndex == null) {
            throw new UntrainedModelException("CPT has not been trained");
        }
        if (targetClass < 0 || targetClass >= this.catIndexToRealIndex.length || this.catIndexToRealIndex[targetClass] < 0) {
            throw new IllegalArgumentException("Category index " + targetClass + " is not part of this CPT");
        }
        return this.catIndexToRealIndex[targetClass];
    }

    /**
     * Multiplies the running cell count by the size of one more dimension,
     * rejecting tables that can not be held in a single array.
     *
     * @param cells the cell count so far (at most {@link Integer#MAX_VALUE})
     * @param categories the number of categories of the added dimension
     * @return the new cell count
     * @throws FailedToFitException if the new cell count exceeds {@link Integer#MAX_VALUE}
     */
    private static long growCellCount(long cells, int categories) {
        final long grown = cells * categories;
        if (grown > Integer.MAX_VALUE) {
            throw new FailedToFitException("CPT would need more than " + Integer.MAX_VALUE + " cells, which can not be stored in a single array");
        }
        return grown;
    }

    /**
     * Converts a table coordinate into the index of its cell in
     * {@link #countArray}.
     *
     * @param cords one value per table dimension
     * @return the flat index of the cell
     * @throws IllegalArgumentException if the number of values does not match
     *         the number of table dimensions
     */
    private int cordToIndex(int ... cords) {
        if (cords.length != this.realIndexToCatIndex.length) {
            throw new IllegalArgumentException("Coordinate has " + cords.length + " values but the CPT has " + this.realIndexToCatIndex.length + " dimensions");
        }
        int index = 0;
        for (int i = 0; i < cords.length; ++i) {
            index = cords[i] + this.dimSize[i] * index;
        }
        return index;
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Creates a copy of this table. The count table and the index mappings are
     * copied, so training the copy does not affect this object. The
     * {@link CategoricalData} descriptors are shared between both objects, the
     * same way {@link #trainC(ClassificationDataSet, Set)} shares them with the
     * data set. An untrained table yields an untrained copy.
     *
     * @return an independent copy of this table
     */
    @Override
    public Classifier clone() {
        ConditionalProbabilityTable copy = new ConditionalProbabilityTable();
        copy.predicting = this.predicting;
        copy.predictingIndex = this.predictingIndex;
        copy.countArray = this.countArray == null ? null : Arrays.copyOf(this.countArray, this.countArray.length);
        copy.valid = this.valid == null ? null : new HashMap<Integer, CategoricalData>(this.valid);
        copy.realIndexToCatIndex = this.realIndexToCatIndex == null ? null : Arrays.copyOf(this.realIndexToCatIndex, this.realIndexToCatIndex.length);
        copy.catIndexToRealIndex = this.catIndexToRealIndex == null ? null : Arrays.copyOf(this.catIndexToRealIndex, this.catIndexToRealIndex.length);
        copy.dimSize = this.dimSize == null ? null : Arrays.copyOf(this.dimSize, this.dimSize.length);
        return copy;
    }
}

