/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.DataStore;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/**
 * A {@link DataSet} for classification problems. Every data point is paired with
 * an integer class label (its "target"), and {@link #getPredicting()} describes the
 * categorical variable those labels index into.
 * <p>
 * Targets are stored in {@link #targets}, parallel to the rows of the underlying
 * {@link DataStore}: index {@code i} of one corresponds to index {@code i} of the other.
 */
public class ClassificationDataSet
extends DataSet<ClassificationDataSet> {
    /** Describes the class variable whose values are stored in {@link #targets}. */
    protected CategoricalData predicting;
    /** Class label of each data point, in the same order as the data store rows. */
    protected IntList targets;
    /** Empty categorical-value array used for points that carry only numeric features. */
    private static final int[] emptyInt = new int[0];

    /**
     * Creates a classification data set from a generic data set by promoting one of
     * its categorical variables to be the class label.
     *
     * @param dataSet    the source data set; its numeric variable names and per-point
     *                   weights are copied over
     * @param predicting index of the categorical variable to use as the target
     */
    public ClassificationDataSet(DataSet<?> dataSet, int predicting) {
        this(dataSet.getDataPoints(), predicting);
        for (int i = 0; i < this.getNumNumericalVars(); ++i) {
            this.numericalVariableNames.put(i, dataSet.getNumericName(i));
        }
        for (int i = 0; i < dataSet.size(); ++i) {
            this.setWeight(i, dataSet.getWeight(i));
        }
    }

    /**
     * Creates a classification data set from a list of data points. The categorical
     * variable at index {@code predicting} is removed from each point and used as its
     * class label. All remaining categorical variables stay as features.
     *
     * @param data       the data points; must be non-empty, since the first point
     *                   defines the variable layout
     * @param predicting index of the categorical variable to use as the target
     */
    public ClassificationDataSet(List<DataPoint> data, int predicting) {
        super(data.get(0).numNumericalValues(), data.get(0).getCategoricalData());
        DataPoint first = data.get(0);
        CategoricalData[] allCats = first.getCategoricalData();
        this.categories = new CategoricalData[first.numCategoricalValues() - 1];
        for (int i = 0; i < this.categories.length; ++i) {
            // Skip over the target column: indices at or after it shift by one.
            this.categories[i] = i >= predicting ? allCats[i + 1] : allCats[i];
        }
        this.datapoints.setCategoricalDataInfo(this.categories);
        this.predicting = allCats[predicting];
        this.targets = new IntList(data.size());
        for (DataPoint dp : data) {
            int[] prevCats = dp.getCategoricalValues();
            int[] newCats = new int[dp.numCategoricalValues() - 1];
            int k = 0;
            for (int i = 0; i < prevCats.length; ++i) {
                if (i != predicting) {
                    newCats[k++] = prevCats[i];
                }
            }
            this.datapoints.addDataPoint(new DataPoint(dp.getNumericalValues(), newCats, this.categories));
            this.targets.add(prevCats[predicting]);
        }
    }

    /**
     * Wraps an existing data store with the given labels. The class variable is
     * created with {@code max(targets) + 1} categories.
     * <p>
     * Note: if {@code targets} is empty, {@code getAsInt()} throws a
     * {@link java.util.NoSuchElementException}.
     *
     * @param datapoints the feature rows
     * @param targets    the class label of each row
     */
    public ClassificationDataSet(DataStore datapoints, List<Integer> targets) {
        this(datapoints, targets, new CategoricalData(targets.stream().mapToInt(i -> i).max().getAsInt() + 1));
    }

    /**
     * Wraps an existing data store with the given labels and class-variable description.
     *
     * @param datapoints the feature rows
     * @param targets    the class label of each row (copied)
     * @param predicting description of the class variable
     */
    public ClassificationDataSet(DataStore datapoints, List<Integer> targets, CategoricalData predicting) {
        super(datapoints);
        this.targets = new IntList(targets);
        this.predicting = predicting;
    }

    /**
     * Creates a data set from (data point, label) pairs.
     *
     * @param data       the pairs; must be non-empty, since the first point defines
     *                   the variable layout
     * @param predicting description of the class variable
     */
    public ClassificationDataSet(List<DataPointPair<Integer>> data, CategoricalData predicting) {
        super(data.get(0).getVector().length(), data.get(0).getDataPoint().getCategoricalData());
        this.predicting = predicting;
        this.categories = CategoricalData.copyOf(data.get(0).getDataPoint().getCategoricalData());
        this.targets = new IntList(data.size());
        for (DataPointPair<Integer> dpp : data) {
            this.datapoints.addDataPoint(dpp.getDataPoint());
            this.targets.add(dpp.getPair());
        }
    }

    /**
     * Creates an empty data set with the given variable layout.
     *
     * @param numerical  number of numeric features
     * @param categories categorical feature descriptions
     * @param predicting description of the class variable
     */
    public ClassificationDataSet(int numerical, CategoricalData[] categories, CategoricalData predicting) {
        super(numerical, categories);
        this.predicting = predicting;
        this.targets = new IntList();
    }

    /**
     * Returns the number of classes, that is, the number of categories of the
     * predicting variable.
     *
     * @return the number of target classes
     */
    public int getClassSize() {
        return this.predicting.getNumOfCategories();
    }

    /**
     * Concatenates every data set in {@code list} except the one at index
     * {@code exception}. Per-point weights are carried over, and the result uses the
     * predicting variable of the first data set. (The method name's misspelling is
     * kept for API compatibility.)
     *
     * @param list      the data sets to combine; the first one defines the layout
     * @param exception index of the data set to leave out
     * @return a new data set holding the rows of all other data sets, in list order
     */
    public static ClassificationDataSet comineAllBut(List<ClassificationDataSet> list, int exception) {
        ClassificationDataSet first = list.get(0);
        int numer = first.getNumNumericalVars();
        CategoricalData[] categories = first.getCategories();
        CategoricalData predicting = first.getPredicting();
        if (first.rowMajor()) {
            ClassificationDataSet cds = new ClassificationDataSet(numer, categories, predicting);
            for (int i = 0; i < list.size(); ++i) {
                if (i == exception) {
                    continue;
                }
                ClassificationDataSet part = list.get(i);
                for (int j = 0; j < part.size(); ++j) {
                    cds.datapoints.addDataPoint(part.getDataPoint(j));
                    cds.targets.add(part.getDataPointCategory(j));
                    cds.setWeight(cds.size() - 1, part.getWeight(j));
                }
            }
            return cds;
        }
        // Column-major storage: stream rows into a fresh store, then finalize it once.
        DataStore ds = first.datapoints.emptyClone();
        IntList newTargets = new IntList();
        for (int k = 0; k < list.size(); ++k) {
            if (k == exception) {
                continue;
            }
            ClassificationDataSet part = list.get(k);
            Iterator<DataPoint> iter = part.datapoints.getRowIter();
            int pos = 0;
            while (iter.hasNext()) {
                ds.addDataPoint(iter.next());
                newTargets.add(part.getDataPointCategory(pos++));
            }
        }
        ds.finishAdding();
        // Reuse the known class variable rather than inferring it from the labels seen,
        // which would under-count classes missing from the combination.
        ClassificationDataSet combined = new ClassificationDataSet(ds, newTargets, predicting);
        int row = 0;
        for (int k = 0; k < list.size(); ++k) {
            if (k == exception) {
                continue;
            }
            ClassificationDataSet part = list.get(k);
            for (int j = 0; j < part.size(); ++j) {
                combined.setWeight(row++, part.getWeight(j));
            }
        }
        return combined;
    }

    @Override
    public DataPoint getDataPoint(int i) {
        return this.getDataPointPair(i).getDataPoint();
    }

    /**
     * Returns the data point at index {@code i} paired with its class label.
     *
     * @param i the row index
     * @return the (data point, label) pair
     * @throws IndexOutOfBoundsException if {@code i} is negative or at least {@link #size()}
     */
    public DataPointPair<Integer> getDataPointPair(int i) {
        if (i >= this.size()) {
            throw new IndexOutOfBoundsException("There are not that many samples in the data set");
        }
        if (i < 0) {
            throw new IndexOutOfBoundsException("Can not specify negative index " + i);
        }
        return new DataPointPair<Integer>(this.datapoints.getDataPoint(i), this.targets.getI(i));
    }

    @Override
    public void setDataPoint(int i, DataPoint dp) {
        if (i >= this.size()) {
            throw new IndexOutOfBoundsException("There are not that many samples in the data set");
        }
        if (i < 0) {
            throw new IndexOutOfBoundsException("Can not specify negative index " + i);
        }
        this.datapoints.setDataPoint(i, dp);
    }

    /**
     * Returns the class label of the data point at index {@code i}.
     *
     * @param i the row index
     * @return the class label
     * @throws IndexOutOfBoundsException if {@code i} is negative or at least {@link #size()}
     */
    public int getDataPointCategory(int i) {
        if (i >= this.size()) {
            throw new IndexOutOfBoundsException("There are not that many samples in the data set: " + i);
        }
        if (i < 0) {
            throw new IndexOutOfBoundsException("Can not specify negative index " + i);
        }
        return this.targets.get(i);
    }

    /**
     * Builds a new data set holding the rows at {@code indicies}, in that order.
     * Duplicate indices produce duplicate rows, so the result supports sampling with
     * replacement. Labels, weights, and the predicting variable are carried over.
     *
     * @throws IndexOutOfBoundsException if any index is out of range
     */
    @Override
    protected ClassificationDataSet getSubset(List<Integer> indicies) {
        if (this.datapoints.rowMajor()) {
            ClassificationDataSet newData = new ClassificationDataSet(this.numNumerVals, this.categories, this.predicting);
            for (int i : indicies) {
                newData.addDataPoint(this.getDataPoint(i), this.getDataPointCategory(i), this.getWeight(i));
            }
            return newData;
        }
        // Column-major storage: gather the requested rows in one sequential pass,
        // then emit them in the caller's order.
        HashMap<Integer, DataPoint> wanted = new HashMap<Integer, DataPoint>(indicies.size());
        for (int idx : indicies) {
            wanted.put(idx, null);
        }
        Iterator<DataPoint> dataIter = this.datapoints.getRowIter();
        int origPos = 0;
        while (dataIter.hasNext()) {
            DataPoint dp = dataIter.next();
            if (wanted.containsKey(origPos)) {
                wanted.put(origPos, new DataPoint(dp.getNumericalValues().clone(),
                        Arrays.copyOf(dp.getCategoricalValues(), this.getNumCategoricalVars()), this.categories));
            }
            ++origPos;
        }
        DataStore newDs = this.datapoints.emptyClone();
        IntList newTargets = new IntList(indicies.size());
        for (int idx : indicies) {
            // Look up the label first: it bounds-checks idx before its row is used.
            newTargets.add(this.getDataPointCategory(idx));
            newDs.addDataPoint(wanted.get(idx));
        }
        newDs.finishAdding();
        ClassificationDataSet subset = new ClassificationDataSet(newDs, newTargets, this.predicting);
        for (int k = 0; k < indicies.size(); ++k) {
            subset.setWeight(k, this.getWeight(indicies.get(k)));
        }
        return subset;
    }

    /**
     * Splits the data into {@code folds} data sets with roughly equal class
     * proportions. Within each class the points are shuffled and then dealt
     * round-robin across the folds. Per-point weights are preserved.
     *
     * @param folds the number of folds; must be positive
     * @param rnd   source of randomness for the shuffles
     * @return the list of folds
     * @throws IllegalArgumentException if {@code folds} is not positive
     */
    public List<ClassificationDataSet> stratSet(int folds, Random rnd) {
        if (folds <= 0) {
            throw new IllegalArgumentException("Number of folds must be positive, not " + folds);
        }
        ArrayList<ClassificationDataSet> cvList = new ArrayList<ClassificationDataSet>(folds);
        while (cvList.size() < folds) {
            cvList.add(new ClassificationDataSet(this.numNumerVals, this.categories, this.predicting.clone()));
        }
        IntList rndOrder = new IntList();
        int curFold = 0;
        for (int c = 0; c < this.getClassSize(); ++c) {
            // Row indices of every point in class c, in storage order.
            IntList members = new IntList();
            for (int i = 0; i < this.targets.size(); ++i) {
                if (this.targets.getI(i) == c) {
                    members.add(i);
                }
            }
            rndOrder.clear();
            ListUtils.addRange(rndOrder, 0, members.size(), 1);
            Collections.shuffle(rndOrder, rnd);
            for (int r : rndOrder) {
                int idx = members.getI(r);
                cvList.get(curFold).addDataPoint(this.datapoints.getDataPoint(idx), c, this.getWeight(idx));
                curFold = (curFold + 1) % folds;
            }
        }
        return cvList;
    }

    /** Adds a point with the given categorical values and weight 1. */
    public void addDataPoint(Vec v, int[] classes, int classification) {
        this.addDataPoint(v, classes, classification, 1.0);
    }

    /** Adds a point with only numeric features and weight 1. */
    public void addDataPoint(Vec v, int classification) {
        this.addDataPoint(v, emptyInt, classification, 1.0);
    }

    /** Adds a point with only numeric features and the given weight. */
    public void addDataPoint(Vec v, int classification, double weight) {
        this.addDataPoint(v, emptyInt, classification, weight);
    }

    /**
     * Adds a new point built from numeric values and categorical values.
     *
     * @param v              numeric features; length must equal the numeric variable count
     * @param classes        categorical feature values; negative values are accepted
     * @param classification class label of the point
     * @param weight         weight of the point
     * @throws IllegalArgumentException if the shape does not match or a category value is invalid
     */
    public void addDataPoint(Vec v, int[] classes, int classification, double weight) {
        if (v.length() != this.numNumerVals) {
            throw new IllegalArgumentException("Data point does not contain enough numerical data points");
        }
        if (classes.length != this.categories.length) {
            throw new IllegalArgumentException("Data point does not contain enough categorical data points");
        }
        for (int i = 0; i < classes.length; ++i) {
            if (!this.categories[i].isValidCategory(classes[i]) && classes[i] >= 0) {
                throw new IllegalArgumentException("Categoriy value given is invalid");
            }
        }
        this.datapoints.addDataPointCheck(new DataPoint(v, classes, this.categories));
        this.targets.add(classification);
        this.setWeight(this.size() - 1, weight);
    }

    /** Adds an existing data point with weight 1. */
    public void addDataPoint(DataPoint dp, int classification) {
        this.addDataPoint(dp, classification, 1.0);
    }

    /**
     * Adds an existing data point.
     *
     * @param dp             the point; its shape must match this data set
     * @param classification class label of the point
     * @param weight         weight of the point
     * @throws IllegalArgumentException if the shape does not match or a category value is invalid
     */
    public void addDataPoint(DataPoint dp, int classification, double weight) {
        if (dp.getNumericalValues().length() != this.numNumerVals) {
            throw new IllegalArgumentException("Data point does not contain enough numerical data points");
        }
        int[] catVals = dp.getCategoricalValues();
        if (catVals.length != this.categories.length) {
            throw new IllegalArgumentException("Data point does not contain enough categorical data points");
        }
        for (int i = 0; i < catVals.length; ++i) {
            int val = catVals[i];
            if (!this.categories[i].isValidCategory(val) && val >= 0) {
                throw new IllegalArgumentException("Categoriy value given is invalid");
            }
        }
        this.datapoints.addDataPointCheck(dp);
        this.targets.add(classification);
        this.setWeight(this.size() - 1, weight);
    }

    /**
     * Returns every data point whose label equals {@code category}, in storage order.
     *
     * @param category the class label to select
     * @return a new list of matching data points
     */
    public List<DataPoint> getSamples(int category) {
        ArrayList<DataPoint> subSet = new ArrayList<DataPoint>();
        for (int i = 0; i < this.targets.size(); ++i) {
            if (this.targets.getI(i) == category) {
                subSet.add(this.datapoints.getDataPoint(i));
            }
        }
        return subSet;
    }

    /**
     * Collects numeric variable {@code n} from every point of class {@code category}.
     *
     * @param category the class label to select
     * @param n        index of the numeric variable
     * @return a dense vector with one entry per matching point
     */
    public Vec getSampleVariableVector(int category, int n) {
        List<DataPoint> categoryList = this.getSamples(category);
        DenseVector vec = new DenseVector(categoryList.size());
        for (int i = 0; i < vec.length(); ++i) {
            vec.set(i, categoryList.get(i).getNumericalValues().get(n));
        }
        return vec;
    }

    /** @return the description of the class variable */
    public CategoricalData getPredicting() {
        return this.predicting;
    }

    /** @return a new list of (data point, label) pairs, one per row */
    public List<DataPointPair<Integer>> getAsDPPList() {
        ArrayList<DataPointPair<Integer>> dataPoints = new ArrayList<DataPointPair<Integer>>(this.size());
        for (int i = 0; i < this.size(); ++i) {
            dataPoints.add(new DataPointPair<Integer>(this.datapoints.getDataPoint(i), this.targets.get(i)));
        }
        return dataPoints;
    }

    /** @return a new list of (data point, label-as-double) pairs, one per row */
    public List<DataPointPair<Double>> getAsFloatDPPList() {
        ArrayList<DataPointPair<Double>> dataPoints = new ArrayList<DataPointPair<Double>>(this.size());
        for (int i = 0; i < this.size(); ++i) {
            dataPoints.add(new DataPointPair<Double>(this.datapoints.getDataPoint(i), Double.valueOf(this.targets.getI(i))));
        }
        return dataPoints;
    }

    /**
     * Computes the weighted fraction of data in each class. If the total weight is
     * zero (for example, an empty data set), the entries are NaN.
     *
     * @return an array of length {@link #getClassSize()} with per-class weight fractions
     */
    public double[] getPriors() {
        double[] priors = new double[this.getClassSize()];
        double sum = 0.0;
        for (int i = 0; i < this.size(); ++i) {
            double w = this.getWeight(i);
            priors[this.targets.getI(i)] += w;
            sum += w;
        }
        for (int c = 0; c < priors.length; ++c) {
            priors[c] /= sum;
        }
        return priors;
    }

    /**
     * Counts the points labelled {@code targetClass}. Weights are ignored.
     *
     * @param targetClass the class label to count
     * @return the number of matching points
     */
    public int classSampleCount(int targetClass) {
        int count = 0;
        for (int i = 0; i < this.targets.size(); ++i) {
            if (this.targets.getI(i) == targetClass) {
                ++count;
            }
        }
        return count;
    }

    @Override
    public int size() {
        return this.datapoints.size();
    }

    /**
     * Returns a new data set that shares this one's data point objects but owns its
     * own label list, weight array, and predicting-variable clone.
     *
     * @return the shallow clone
     */
    public ClassificationDataSet shallowClone() {
        ClassificationDataSet clone = new ClassificationDataSet(this.numNumerVals, this.categories, this.predicting.clone());
        for (int i = 0; i < this.size(); ++i) {
            clone.datapoints.addDataPoint(this.getDataPoint(i));
        }
        clone.targets.addAll(this.targets);
        if (this.weights != null) {
            clone.weights = Arrays.copyOf(this.weights, this.weights.length);
        }
        return clone;
    }

    /** @return an empty data set with the same variable layout */
    public ClassificationDataSet emptyClone() {
        return new ClassificationDataSet(this.numNumerVals, this.categories, this.predicting.clone());
    }

    @Override
    public ClassificationDataSet getTwiceShallowClone() {
        return (ClassificationDataSet)super.getTwiceShallowClone();
    }
}

