/*
 * Decompiled with CFR 0.152.
 */
package jsat;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.IntStream;
import jsat.DataStore;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.FixedDataTransform;
import jsat.datatransform.InPlaceTransform;
import jsat.linear.ConstantVector;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Matrix;
import jsat.linear.MatrixOfVecs;
import jsat.linear.MatrixStatistics;
import jsat.linear.SparseMatrix;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.math.OnLineStatistics;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * Base class for a collection of {@link DataPoint}s that share one feature
 * schema (a number of numeric features plus an array of categorical feature
 * descriptions). The points live in a {@link DataStore}; each point also has a
 * weight, which defaults to 1.0 until a different value is assigned.
 *
 * @param <Type> the concrete data set type returned by the subset / split
 *               operations
 */
public abstract class DataSet<Type extends DataSet> {
    /** Number of numeric features carried by every data point. */
    protected int numNumerVals;
    /** Categorical feature descriptions; may be {@code null} if the backing store reported none. */
    protected CategoricalData[] categories;
    /** User supplied names of numeric features, keyed by feature index. */
    protected Map<Integer, String> numericalVariableNames;
    /** Backing storage of the data points. */
    protected DataStore datapoints;
    /** Per-point weights, or {@code null} while every weight is still the implicit default of 1.0. */
    protected double[] weights;

    /**
     * Creates a data set backed by the given store, taking the feature schema
     * from it.
     *
     * @param datapoints the store that holds (and will hold) the data points
     * @throws IllegalArgumentException if the store reports zero numeric
     *         features and no categorical features
     */
    public DataSet(DataStore datapoints) {
        this.datapoints = datapoints;
        this.numNumerVals = datapoints.numNumeric();
        this.categories = datapoints.getCategoricalDataInfo();
        this.weights = null;
        if (this.numNumerVals == 0 && (this.categories == null || this.categories.length == 0)) {
            throw new IllegalArgumentException("Input must have a non-zero number of features defined");
        }
        this.numericalVariableNames = new HashMap<Integer, String>();
    }

    /**
     * Creates an empty data set with the given schema, backed by an empty
     * clone of {@code DataStore.DEFAULT_STORE}.
     *
     * @param numerical  the number of numeric features
     * @param categories the categorical feature descriptions
     */
    public DataSet(int numerical, CategoricalData[] categories) {
        this.categories = categories;
        this.numNumerVals = numerical;
        this.datapoints = DataStore.DEFAULT_STORE.emptyClone();
        this.datapoints.setNumNumeric(numerical);
        this.datapoints.setCategoricalDataInfo(categories);
        this.numericalVariableNames = new HashMap<Integer, String>();
        this.weights = null;
    }

    /**
     * Replaces the backing store. The new store must be empty; it is given
     * this data set's schema and every current data point is added to it, in
     * index order, before it becomes the backing store.
     *
     * @param store the empty store to switch to
     * @throws RuntimeException if {@code store} already contains data
     */
    public void setDataStore(DataStore store) {
        if (store.size() > 0) {
            throw new RuntimeException("A non-empty data store was provided to an already existing dataset object.");
        }
        store.setCategoricalDataInfo(this.datapoints.getCategoricalDataInfo());
        store.setNumNumeric(this.numNumerVals);
        // The loop is simply skipped when the current store is empty.
        for (int i = 0; i < this.datapoints.size(); ++i) {
            store.addDataPoint(this.getDataPoint(i));
        }
        this.datapoints = store;
    }

    /**
     * @return the value of {@code rowMajor()} reported by the backing store
     */
    public boolean rowMajor() {
        return this.datapoints.rowMajor();
    }

    /**
     * Assigns a name to a numeric feature.
     *
     * @param name the name to use
     * @param i    the index of the numeric feature
     * @return {@code true} if the name was stored, {@code false} if
     *         {@code i} is not a valid numeric feature index
     */
    public boolean setNumericName(String name, int i) {
        if (i < 0 || i >= this.getNumNumericalVars()) {
            return false;
        }
        this.numericalVariableNames.put(i, name);
        return true;
    }

    /**
     * Returns the name of a numeric feature, falling back to
     * {@code "Numeric Feature " + i} when none was assigned.
     *
     * @param i the index of the numeric feature
     * @return the feature's name
     * @throws IndexOutOfBoundsException if {@code i} is not a valid numeric
     *         feature index
     */
    public String getNumericName(int i) {
        if (i < 0 || i >= this.getNumNumericalVars()) {
            throw new IndexOutOfBoundsException("Can not acces variable for invalid index  " + i);
        }
        return this.numericalVariableNames.getOrDefault(i, "Numeric Feature " + i);
    }

    /**
     * Returns the name of a categorical feature.
     *
     * @param i the index of the categorical feature
     * @return the name reported by the feature's {@link CategoricalData}
     * @throws IndexOutOfBoundsException if {@code i} is not a valid
     *         categorical feature index
     */
    public String getCategoryName(int i) {
        if (i < 0 || i >= this.getNumCategoricalVars()) {
            throw new IndexOutOfBoundsException("Can not acces variable for invalid index  " + i);
        }
        return this.categories[i].getCategoryName();
    }

    /**
     * Applies the transform to every data point, sequentially and without
     * in-place mutation.
     *
     * @param dt the transform to apply
     */
    public void applyTransform(DataTransform dt) {
        this.applyTransform(dt, false);
    }

    /**
     * Applies the fixed transform to every data point, sequentially.
     *
     * @param dt the transform to apply
     */
    public void applyTransform(FixedDataTransform dt) {
        this.applyTransform(dt, false);
    }

    /**
     * Applies the transform to every data point without in-place mutation.
     *
     * @param dt       the transform to apply
     * @param parallel forwarded to {@code ParallelUtils.range} to select
     *                 parallel or sequential processing
     */
    public void applyTransform(DataTransform dt, boolean parallel) {
        this.applyTransformMutate(dt, false, parallel);
    }

    /**
     * Applies the fixed transform to every data point by wrapping it in an
     * adapter {@link DataTransform} whose {@code fit} does nothing and whose
     * {@code clone} returns the adapter itself.
     *
     * @param dt       the transform to apply
     * @param parallel forwarded to {@code ParallelUtils.range} to select
     *                 parallel or sequential processing
     */
    public void applyTransform(final FixedDataTransform dt, boolean parallel) {
        DataTransform adapter = new DataTransform(){

            @Override
            public DataPoint transform(DataPoint dp) {
                return dt.transform(dp);
            }

            @Override
            public void fit(DataSet data) {
                // Nothing to learn: the wrapped transform is fixed.
            }

            @Override
            public DataTransform clone() {
                return this;
            }
        };
        this.applyTransformMutate(adapter, false, parallel);
    }

    /**
     * Sequential variant of
     * {@link #applyTransformMutate(DataTransform, boolean, boolean)}.
     *
     * @param dt     the transform to apply
     * @param mutate whether in-place mutation may be used
     */
    public void applyTransformMutate(DataTransform dt, boolean mutate) {
        this.applyTransformMutate(dt, mutate, false);
    }

    /**
     * Applies the transform to every data point. When {@code mutate} is set
     * and the transform is an {@link InPlaceTransform}, the existing points
     * are altered through {@code mutableTransform}; otherwise each point is
     * replaced by the result of {@code transform}. Afterwards the schema is
     * re-read from the first data point and all numeric feature names are
     * cleared.
     * <p>
     * An empty data set is left untouched: there is no point to transform and
     * no first point to read the new schema from.
     *
     * @param dt       the transform to apply
     * @param mutate   whether in-place mutation may be used
     * @param parallel forwarded to {@code ParallelUtils.range} to select
     *                 parallel or sequential processing
     */
    public void applyTransformMutate(DataTransform dt, boolean mutate, boolean parallel) {
        if (this.size() == 0) {
            // Reading getDataPoint(0) below would fail on an empty data set.
            return;
        }
        if (mutate && dt instanceof InPlaceTransform) {
            InPlaceTransform ipt = (InPlaceTransform)dt;
            ParallelUtils.range(this.size(), parallel).forEach(i -> ipt.mutableTransform(this.getDataPoint(i)));
        } else {
            ParallelUtils.range(this.size(), parallel).forEach(i -> this.setDataPoint(i, dt.transform(this.getDataPoint(i))));
            // Replaced points may have a different schema; tell the store.
            this.datapoints.setNumNumeric(this.getDataPoint(0).numNumericalValues());
            this.datapoints.setCategoricalDataInfo(this.getDataPoint(0).getCategoricalData());
        }
        this.numNumerVals = this.getDataPoint(0).numNumericalValues();
        this.categories = this.getDataPoint(0).getCategoricalData();
        if (this.numericalVariableNames != null) {
            this.numericalVariableNames.clear();
        }
    }

    /**
     * Replaces the numeric features of every data point while keeping its
     * categorical values. The number of numeric features is then re-read from
     * the first data point and all numeric feature names are cleared.
     * <p>
     * When the data set (and therefore the list) is empty nothing is changed.
     *
     * @param newNumericFeatures the new numeric vectors, one per data point
     *                           and in data point order
     * @throws RuntimeException if the list size differs from {@link #size()}
     */
    public void replaceNumericFeatures(List<Vec> newNumericFeatures) {
        if (this.size() != newNumericFeatures.size()) {
            throw new RuntimeException("Input list does not have the same not of dataums as the dataset");
        }
        if (newNumericFeatures.isEmpty()) {
            // No first point exists to read the new dimension from.
            return;
        }
        for (int i = 0; i < newNumericFeatures.size(); ++i) {
            DataPoint old = this.getDataPoint(i);
            DataPoint replacement = new DataPoint(newNumericFeatures.get(i), old.getCategoricalValues(), old.getCategoricalData());
            this.setDataPoint(i, replacement);
        }
        // NOTE(review): unlike applyTransformMutate, the backing store's numeric
        // count is not updated here - confirm whether DataStore needs it.
        this.numNumerVals = this.getDataPoint(0).numNumericalValues();
        if (this.numericalVariableNames != null) {
            this.numericalVariableNames.clear();
        }
    }

    /**
     * Appends a data point to the backing store and records its weight.
     *
     * @param dp     the data point to add
     * @param weight the weight of the new point
     */
    protected void base_add(DataPoint dp, double weight) {
        this.datapoints.addDataPoint(dp);
        this.setWeight(this.size() - 1, weight);
    }

    /**
     * @param i the index of the data point
     * @return the data point the backing store returns for index {@code i}
     */
    public DataPoint getDataPoint(int i) {
        return this.datapoints.getDataPoint(i);
    }

    /**
     * Stores {@code dp} at index {@code i} of the backing store.
     *
     * @param i  the index to write
     * @param dp the data point to store
     */
    public void setDataPoint(int i, DataPoint dp) {
        this.datapoints.setDataPoint(i, dp);
    }

    /**
     * Computes one {@link OnLineStatistics} per numeric feature. NaN values
     * are left out of the statistics. Entries that the vector iterator does
     * not visit are accounted for as zeros: each feature receives a 0.0 whose
     * weight is the total weight minus the weight already added and minus the
     * weight of its NaN entries.
     *
     * @param useWeights {@code true} to weigh each point by
     *                   {@link #getWeight(int)}, {@code false} to use 1.0
     * @return an array with one statistics object per numeric feature
     */
    public OnLineStatistics[] getOnlineColumnStats(boolean useWeights) {
        OnLineStatistics[] stats = new OnLineStatistics[this.numNumerVals];
        for (int i = 0; i < stats.length; ++i) {
            stats[i] = new OnLineStatistics();
        }
        double totalSoW = 0.0;
        double[] nanWeight = new double[this.numNumerVals];
        int pos = 0;
        Iterator<DataPoint> iter = this.getDataPointIterator();
        while (iter.hasNext()) {
            DataPoint dp = iter.next();
            double weight = useWeights ? this.getWeight(pos++) : 1.0;
            totalSoW += weight;
            for (IndexValue iv : dp.getNumericalValues()) {
                if (Double.isNaN(iv.getValue())) {
                    nanWeight[iv.getIndex()] += weight;
                } else {
                    stats[iv.getIndex()].add(iv.getValue(), weight);
                }
            }
        }
        for (int i = 0; i < stats.length; ++i) {
            double implicitZeroWeight = totalSoW - stats[i].getSumOfWeights() - nanWeight[i];
            // The difference of floating point sums can come out slightly below
            // zero; clamp it so rounding noise is never passed on as a negative
            // weight. NOTE(review): presumably OnLineStatistics rejects negative
            // weights - confirm.
            stats[i].add(0.0, Math.max(0.0, implicitZeroWeight));
        }
        return stats;
    }

    /**
     * Collects, over all data points, the ratio of {@code nnz()} of the
     * numeric vector to the number of numeric features.
     *
     * @return the statistics of those ratios
     */
    public OnLineStatistics getOnlineDenseStats() {
        OnLineStatistics stats = new OnLineStatistics();
        double numFeatures = this.getNumNumericalVars();
        for (int i = 0; i < this.size(); ++i) {
            stats.add((double)this.getDataPoint(i).getNumericalValues().nnz() / numFeatures);
        }
        return stats;
    }

    /**
     * Computes the per-column mean and the diagonal of the covariance of the
     * numeric features via {@link MatrixStatistics}.
     *
     * @return a two element array: index 0 is filled by
     *         {@code MatrixStatistics.meanVector}, index 1 by
     *         {@code MatrixStatistics.covarianceDiag}
     */
    public Vec[] getColumnMeanVariance() {
        int d = this.getNumNumericalVars();
        Vec means = new DenseVector(d);
        Vec variances = new DenseVector(d);
        MatrixStatistics.meanVector(means, this);
        MatrixStatistics.covarianceDiag(means, variances, this);
        return new Vec[]{means, variances};
    }

    /**
     * @return the row iterator of the backing store
     */
    public Iterator<DataPoint> getDataPointIterator() {
        return this.datapoints.getRowIter();
    }

    /**
     * @return the number of data points in the backing store
     */
    public int size() {
        return this.datapoints.size();
    }

    /**
     * @return {@code true} if the data set holds no data points
     */
    public boolean isEmpty() {
        return this.size() == 0;
    }

    /**
     * @return the same value as {@link #size()}
     */
    public int getSampleSize() {
        return this.size();
    }

    /**
     * @return the number of categorical features, which is 0 when no
     *         categorical descriptions are present
     */
    public int getNumCategoricalVars() {
        // The DataStore constructor accepts a null category array for purely
        // numeric data, so it must not be dereferenced blindly.
        return this.categories == null ? 0 : this.categories.length;
    }

    /**
     * @return the number of numeric features
     */
    public int getNumNumericalVars() {
        return this.numNumerVals;
    }

    /**
     * @return the internal array of categorical feature descriptions (not a
     *         copy)
     */
    public CategoricalData[] getCategories() {
        return this.categories;
    }

    /**
     * Builds a data set from a selection of this data set's points.
     * Implementations receive indices into this data set.
     *
     * @param var1 the indices of the data points to select
     * @return the data set for the selected points
     */
    protected abstract Type getSubset(List<Integer> var1);

    /**
     * Returns the subset of data points that have no missing value, i.e. no
     * NaN numeric value and no negative categorical value.
     *
     * @return the subset without missing values
     */
    public Type getMissingDropped() {
        IntList hasNoMissing = new IntList();
        for (int i = 0; i < this.size(); ++i) {
            if (!DataSet.hasMissingValue(this.getDataPoint(i))) {
                hasNoMissing.add(Integer.valueOf(i));
            }
        }
        return this.getSubset(hasNoMissing);
    }

    /** Tells whether the point has a NaN numeric value or a negative categorical value. */
    private static boolean hasMissingValue(DataPoint dp) {
        if (dp.getNumericalValues().countNaNs() > 0) {
            return true;
        }
        for (int c : dp.getCategoricalValues()) {
            if (c < 0) {
                return true;
            }
        }
        return false;
    }

    /**
     * Shuffles the data point indices and cuts them into consecutive pieces
     * whose sizes follow the given fractions. Piece {@code i} ends at
     * {@code round(sum(splits[0..i]) * size())}. For a store that is not row
     * major the indices of each piece are sorted before the subset is built.
     *
     * @param rand   the source of randomness for the shuffle
     * @param splits the fraction of the data for each piece
     * @return one data set per entry of {@code splits}
     * @throws IllegalArgumentException if {@code splits} is empty or its
     *         running sum reaches 1.001
     */
    public List<Type> randomSplit(Random rand, double ... splits) {
        if (splits.length < 1) {
            throw new IllegalArgumentException("Input array of split fractions must be non-empty");
        }
        IntList randOrder = new IntList(this.size());
        ListUtils.addRange(randOrder, 0, this.size(), 1);
        Collections.shuffle(randOrder, rand);
        int[] stops = new int[splits.length];
        double sum = 0.0;
        for (int i = 0; i < splits.length; ++i) {
            sum += splits[i];
            // A small tolerance absorbs rounding in fractions that sum to 1.
            if (sum >= 1.001) {
                throw new IllegalArgumentException("Input splits sum is greater than 1 by index " + i + " reaching a sum of " + sum);
            }
            stops[i] = (int)Math.round(sum * (double)randOrder.size());
        }
        ArrayList<Type> datasets = new ArrayList<Type>(splits.length);
        int prev = 0;
        for (int stop : stops) {
            List<Integer> subList = randOrder.subList(prev, stop);
            if (!this.rowMajor()) {
                Collections.sort(subList);
            }
            datasets.add(this.getSubset(subList));
            prev = stop;
        }
        return datasets;
    }

    /**
     * Same as {@link #randomSplit(Random, double...)} using
     * {@code RandomUtil.getRandom()}.
     *
     * @param splits the fraction of the data for each piece
     * @return one data set per entry of {@code splits}
     */
    public List<Type> randomSplit(double ... splits) {
        return this.randomSplit(RandomUtil.getRandom(), splits);
    }

    /**
     * Randomly splits the data into {@code folds} pieces, each requested with
     * the fraction {@code 1.0 / folds}.
     *
     * @param folds the number of pieces
     * @param rand  the source of randomness
     * @return the list of pieces
     */
    public List<Type> cvSet(int folds, Random rand) {
        double[] splits = new double[folds];
        Arrays.fill(splits, 1.0 / (double)folds);
        return this.randomSplit(rand, splits);
    }

    /**
     * Same as {@link #cvSet(int, Random)} using
     * {@code RandomUtil.getRandom()}.
     *
     * @param folds the number of pieces
     * @return the list of pieces
     */
    public List<Type> cvSet(int folds) {
        return this.cvSet(folds, RandomUtil.getRandom());
    }

    /**
     * @return a new list holding every data point in index order
     */
    public List<DataPoint> getDataPoints() {
        int n = this.size();
        ArrayList<DataPoint> list = new ArrayList<DataPoint>(n);
        for (int i = 0; i < n; ++i) {
            list.add(this.getDataPoint(i));
        }
        return list;
    }

    /**
     * @return a new list holding the numeric vector of every data point in
     *         index order
     */
    public List<Vec> getDataVectors() {
        int n = this.size();
        ArrayList<Vec> vecs = new ArrayList<Vec>(n);
        for (int i = 0; i < n; ++i) {
            vecs.add(this.getDataPoint(i).getNumericalValues());
        }
        return vecs;
    }

    /**
     * @param i the index of the numeric feature
     * @return the numeric column the backing store returns for {@code i}
     */
    public Vec getNumericColumn(int i) {
        return this.datapoints.getNumericColumn(i);
    }

    /**
     * Counts the missing values: NaN numeric values plus negative categorical
     * values. A row major store is scanned point by point, any other store
     * column by column.
     *
     * @return the number of missing values
     */
    public long countMissingValues() {
        long missing = 0L;
        if (this.rowMajor()) {
            for (int i = 0; i < this.size(); ++i) {
                DataPoint dp = this.getDataPoint(i);
                missing += (long)dp.getNumericalValues().countNaNs();
                for (int c : dp.getCategoricalValues()) {
                    if (c < 0) {
                        ++missing;
                    }
                }
            }
            return missing;
        }
        for (int j = 0; j < this.getNumNumericalVars(); ++j) {
            missing += (long)this.datapoints.getNumericColumn(j).countNaNs();
        }
        for (int j = 0; j < this.getNumCategoricalVars(); ++j) {
            missing += IntStream.of(this.datapoints.getCatColumn(j)).filter(z -> z < 0).count();
        }
        return missing;
    }

    /**
     * @return the numeric columns of the backing store with no column skipped
     */
    public Vec[] getNumericColumns() {
        return this.getNumericColumns(Collections.<Integer>emptySet());
    }

    /**
     * @param skipColumns the column indices passed to the backing store as
     *                    columns to skip
     * @return the numeric columns returned by the backing store
     */
    public Vec[] getNumericColumns(Set<Integer> skipColumns) {
        return this.datapoints.getNumericColumns(skipColumns);
    }

    /**
     * Copies the numeric values into a new matrix with one row per data
     * point. A {@link SparseMatrix} is built when the first data point's
     * numeric vector is sparse, a {@link DenseMatrix} otherwise (including
     * for an empty data set).
     *
     * @return the matrix of numeric values
     */
    public Matrix getDataMatrix() {
        int n = this.size();
        if (n > 0 && this.getDataPoint(0).getNumericalValues().isSparse()) {
            SparseVector[] rows = new SparseVector[n];
            for (int i = 0; i < n; ++i) {
                rows[i] = new SparseVector(this.getDataPoint(i).getNumericalValues());
            }
            return new SparseMatrix(rows);
        }
        DenseMatrix matrix = new DenseMatrix(n, this.getNumNumericalVars());
        for (int i = 0; i < n; ++i) {
            Vec row = this.getDataPoint(i).getNumericalValues();
            for (int j = 0; j < row.length(); ++j) {
                matrix.set(i, j, row.get(j));
            }
        }
        return matrix;
    }

    /**
     * @return a {@link MatrixOfVecs} built from {@link #getDataVectors()}
     */
    public Matrix getDataMatrixView() {
        return new MatrixOfVecs(this.getDataVectors());
    }

    /**
     * @return the number of categorical features plus the number of numeric
     *         features
     */
    public int getNumFeatures() {
        return this.getNumCategoricalVars() + this.getNumNumericalVars();
    }

    /**
     * @return a shallow copy of this data set, as defined by the subclass
     */
    public abstract DataSet<Type> shallowClone();

    /**
     * @return an empty copy of this data set, as defined by the subclass
     */
    public abstract DataSet<Type> emptyClone();

    /**
     * Returns a {@link #shallowClone()} in which every data point has been
     * replaced by a new {@link DataPoint} object that is constructed from the
     * numeric values, categorical values and categorical descriptions of the
     * corresponding point of this data set.
     *
     * @return the clone with re-wrapped data points
     */
    public DataSet getTwiceShallowClone() {
        DataSet<Type> clone = this.shallowClone();
        for (int i = 0; i < clone.size(); ++i) {
            DataPoint source = this.getDataPoint(i);
            clone.setDataPoint(i, new DataPoint(source.getNumericalValues(), source.getCategoricalValues(), source.getCategoricalData()));
        }
        return clone;
    }

    /**
     * @return the sparsity statistics reported by the backing store
     */
    public OnLineStatistics getSparsityStats() {
        return this.datapoints.getSparsityStats();
    }

    /**
     * Sets the weight of one data point. No weight array is allocated as long
     * as every assigned weight equals the default of 1.0.
     *
     * @param i the index of the data point
     * @param w the new weight; must be finite and non-negative
     * @throws IndexOutOfBoundsException if {@code i} is not a valid index
     * @throws ArithmeticException if {@code w} is NaN, infinite or negative
     */
    public void setWeight(int i, double w) {
        if (i >= this.size() || i < 0) {
            throw new IndexOutOfBoundsException("Dataset has only " + this.size() + " members, can't access index " + i);
        }
        if (Double.isNaN(w) || Double.isInfinite(w) || w < 0.0) {
            throw new ArithmeticException("Invalid weight assignment of  " + w);
        }
        if (this.weights == null) {
            if (w == 1.0) {
                // Still all defaults; nothing needs to be stored.
                return;
            }
            this.weights = new double[this.size()];
            Arrays.fill(this.weights, 1.0);
        }
        if (this.weights.length <= i) {
            int oldLength = this.weights.length;
            this.weights = Arrays.copyOf(this.weights, Math.max(oldLength * 2, i + 1));
            // Growing pads with 0.0, but a weight that was never assigned must
            // read as the default 1.0, so fill the new tail accordingly.
            Arrays.fill(this.weights, oldLength, this.weights.length, 1.0);
        }
        this.weights[i] = w;
    }

    /**
     * Returns the weight of one data point.
     *
     * @param i the index of the data point
     * @return the stored weight, or 1.0 if none is stored for {@code i}
     * @throws IndexOutOfBoundsException if {@code i} is not a valid index
     */
    public double getWeight(int i) {
        if (i >= this.size() || i < 0) {
            throw new IndexOutOfBoundsException("Dataset has only " + this.size() + " members, can't access index " + i);
        }
        if (this.weights == null || this.weights.length <= i) {
            return 1.0;
        }
        return this.weights[i];
    }

    /**
     * Returns the weights of all data points as a vector. A
     * {@link ConstantVector} is returned when all weights are equal, so no
     * array is allocated in that case; otherwise the weights are copied into
     * a {@link DenseVector}.
     *
     * @return a vector of length {@link #size()} with the weights
     */
    public Vec getDataWeights() {
        int N = this.size();
        if (N == 0) {
            return new DenseVector(0);
        }
        double weight = this.getWeight(0);
        // Stays null until the first weight that differs from weight 0 shows up.
        double[] weights_copy = null;
        for (int i = 1; i < N; ++i) {
            double w_i = this.getWeight(i);
            if (weights_copy == null) {
                if (weight == w_i) {
                    continue;
                }
                weights_copy = new double[N];
                Arrays.fill(weights_copy, 0, i, weight);
            }
            weights_copy[i] = w_i;
        }
        if (weights_copy == null) {
            return new ConstantVector(weight, N);
        }
        return new DenseVector(weights_copy);
    }
}

