/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import jsat.DataSet;
import jsat.DataStore;
import jsat.RowMajorStore;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.utils.DoubleList;

/**
 * A data set in which every data point is paired with a real-valued regression
 * target.
 * <p>
 * Targets are kept in a {@link DoubleList} whose positions parallel the rows of
 * the underlying {@link DataStore}. Per-point weights are stored by the
 * {@link DataSet} superclass (see {@code setWeight}); the methods in this class
 * that copy rows into a new data set carry those weights across.
 */
public class RegressionDataSet
extends DataSet<RegressionDataSet> {
    /** Regression target for each row, index-aligned with {@code datapoints}. */
    protected DoubleList targets;
    /** Shared empty categorical-value array for purely numerical points. */
    private static final int[] EMPTY_INT = new int[0];

    /**
     * Creates an empty regression data set with the given feature layout.
     *
     * @param numerical  the number of numerical features
     * @param categories the categorical feature descriptions
     */
    public RegressionDataSet(int numerical, CategoricalData[] categories) {
        super(numerical, categories);
        this.targets = new DoubleList();
    }

    /**
     * Creates a regression data set backed by an existing store.
     * <p>
     * NOTE(review): no check is made that {@code targets.size()} matches the
     * number of rows in {@code datapoints}; callers must keep them aligned.
     *
     * @param datapoints the store holding the feature rows
     * @param targets    the target value for each row, in row order (copied)
     */
    public RegressionDataSet(DataStore datapoints, List<Double> targets) {
        super(datapoints);
        this.targets = new DoubleList(targets);
    }

    /**
     * Builds a regression data set by promoting one numerical feature of each
     * point to be the regression target.
     * <p>
     * The numerical feature at index {@code predicting} is removed from every
     * point and used as that point's target. Features after it shift down by
     * one index. Sparse inputs stay sparse. A point whose sparse vector has no
     * stored value at {@code predicting} gets a target of 0.0.
     *
     * @param data       the source points; the first point defines the layout
     * @param predicting index of the numerical feature to use as the target
     * @throws IllegalArgumentException if {@code predicting} is not a valid
     *                                  numerical feature index
     * @throws IndexOutOfBoundsException if {@code data} is empty
     */
    public RegressionDataSet(List<DataPoint> data, int predicting) {
        super(data.get(0).numNumericalValues() - 1, data.get(0).getCategoricalData());
        DataPoint tmp = data.get(0);
        if (predicting < 0 || predicting >= tmp.numNumericalValues()) {
            throw new IllegalArgumentException("Target index " + predicting
                    + " is not a valid numerical feature index (data has "
                    + tmp.numNumericalValues() + " numerical features)");
        }
        // Keep a private copy of the categorical descriptions for this data set.
        this.categories = new CategoricalData[tmp.numCategoricalValues()];
        System.arraycopy(tmp.getCategoricalData(), 0, this.categories, 0, this.categories.length);
        this.targets = new DoubleList(data.size());
        for (DataPoint dp : data) {
            Vec origV = dp.getNumericalValues();
            double target = 0.0;
            Vec newVec = origV.isSparse() ? new SparseVector(origV.length() - 1, origV.nnz()) : new DenseVector(origV.length() - 1);
            for (IndexValue iv : origV) {
                if (iv.getIndex() < predicting) {
                    newVec.set(iv.getIndex(), iv.getValue());
                    continue;
                }
                if (iv.getIndex() == predicting) {
                    target = iv.getValue();
                    continue;
                }
                // Features after the removed one move down one slot.
                newVec.set(iv.getIndex() - 1, iv.getValue());
            }
            DataPoint newDp = new DataPoint(newVec, dp.getCategoricalValues(), this.categories);
            this.datapoints.addDataPoint(newDp);
            this.targets.add(target);
        }
    }

    /**
     * Creates a row-major regression data set from (point, target) pairs.
     * The data points are added as-is (not copied), and the feature layout is
     * taken from the first pair.
     *
     * @param list the pairs to add; must be non-empty
     */
    public RegressionDataSet(List<DataPointPair<Double>> list) {
        super(list.get(0).getDataPoint().numNumericalValues(), CategoricalData.copyOf(list.get(0).getDataPoint().getCategoricalData()));
        this.datapoints = new RowMajorStore(this.numNumerVals, this.categories);
        this.targets = new DoubleList();
        for (DataPointPair<Double> dpp : list) {
            this.datapoints.addDataPoint(dpp.getDataPoint());
            this.targets.add(dpp.getPair());
        }
    }

    /** Placeholder constructor producing an empty one-feature data set. */
    private RegressionDataSet() {
        super(new RowMajorStore(1, new CategoricalData[0]));
        this.targets = new DoubleList();
    }

    /**
     * Concatenates every data set in {@code list} except the one at index
     * {@code exception}. This is typically used to build a cross-validation
     * training fold.
     * Targets and per-point weights are preserved.
     *
     * @param list      the data sets to combine; all must share one layout
     * @param exception index of the data set to leave out; it also supplies
     *                  the feature layout of the result
     * @return a new data set containing the combined points
     */
    public static RegressionDataSet comineAllBut(List<RegressionDataSet> list, int exception) {
        int numer = list.get(exception).getNumNumericalVars();
        CategoricalData[] categories = list.get(exception).getCategories();
        RegressionDataSet rds = new RegressionDataSet(numer, categories);
        for (int i = 0; i < list.size(); ++i) {
            if (i == exception) continue;
            RegressionDataSet src = list.get(i);
            for (int j = 0; j < src.size(); ++j) {
                rds.addDataPoint(src.getDataPoint(j), src.getTargetValue(j), src.getWeight(j));
            }
        }
        return rds;
    }

    /**
     * Adds a purely numerical point with weight 1.0.
     *
     * @param numerical the numerical feature values
     * @param val       the regression target
     */
    public void addDataPoint(Vec numerical, double val) {
        this.addDataPoint(numerical, EMPTY_INT, val);
    }

    /**
     * Adds a point built from raw numerical and categorical values, with
     * weight 1.0. Negative categorical values are accepted without the
     * validity check.
     *
     * @param numerical  the numerical feature values
     * @param categories the categorical feature values
     * @param val        the regression target
     * @throws RuntimeException if the feature counts do not match this data
     *                          set, or a non-negative category value is invalid
     */
    public void addDataPoint(Vec numerical, int[] categories, double val) {
        if (numerical.length() != this.numNumerVals) {
            throw new RuntimeException("Data point does not contain enough numerical data points");
        }
        if (this.categories.length != categories.length) {
            throw new RuntimeException("Data point does not contain enough categorical data points");
        }
        for (int i = 0; i < categories.length; ++i) {
            if (this.categories[i].isValidCategory(categories[i]) || categories[i] < 0) continue;
            throw new RuntimeException("Categoriy value given is invalid");
        }
        DataPoint dp = new DataPoint(numerical, categories, this.categories);
        this.addDataPoint(dp, val);
    }

    /**
     * Adds a data point with weight 1.0.
     *
     * @param dp  the point to add (stored as-is)
     * @param val the regression target
     */
    public void addDataPoint(DataPoint dp, double val) {
        this.addDataPoint(dp, val, 1.0);
    }

    /**
     * Adds a data point with an explicit weight.
     *
     * @param dp     the point to add (stored as-is)
     * @param val    the regression target; must be finite
     * @param weight the weight of the new point
     * @throws RuntimeException    if the point's layout does not match
     * @throws ArithmeticException if {@code val} is NaN or infinite
     */
    public void addDataPoint(DataPoint dp, double val, double weight) {
        if (dp.numNumericalValues() != this.getNumNumericalVars() || dp.numCategoricalValues() != this.getNumCategoricalVars()) {
            throw new RuntimeException("The added data point does not match the number of values and categories for the data set");
        }
        if (Double.isInfinite(val) || Double.isNaN(val)) {
            throw new ArithmeticException("Unregressiable value " + val + " given for regression");
        }
        this.datapoints.addDataPoint(dp);
        this.targets.add(val);
        this.setWeight(this.size() - 1, weight);
    }

    /**
     * Adds a (point, target) pair with weight 1.0.
     *
     * @param pair the pair to add; its target must be non-null and finite
     */
    public void addDataPointPair(DataPointPair<Double> pair) {
        this.addDataPoint(pair.getDataPoint(), (double)pair.getPair());
    }

    /**
     * Returns the {@code i}-th point paired with its target. The point itself
     * is not copied.
     *
     * @param i the row index
     * @return the (point, target) pair
     */
    public DataPointPair<Double> getDataPointPair(int i) {
        return new DataPointPair<Double>(this.getDataPoint(i), this.targets.get(i));
    }

    /**
     * Returns every row as a (point, target) pair, with each point cloned.
     *
     * @return a new list of pairs backed by copies of the points
     */
    public List<DataPointPair<Double>> getAsDPPList() {
        ArrayList<DataPointPair<Double>> list = new ArrayList<DataPointPair<Double>>(this.size());
        for (int i = 0; i < this.size(); ++i) {
            list.add(new DataPointPair<Double>(this.getDataPoint(i).clone(), this.targets.get(i)));
        }
        return list;
    }

    /**
     * Returns every row as a (point, target) pair, sharing the stored points.
     *
     * @return a new list of pairs referencing this data set's points
     */
    public List<DataPointPair<Double>> getDPPList() {
        ArrayList<DataPointPair<Double>> list = new ArrayList<DataPointPair<Double>>(this.size());
        for (int i = 0; i < this.size(); ++i) {
            list.add(this.getDataPointPair(i));
        }
        return list;
    }

    /**
     * Replaces the target of row {@code i}.
     *
     * @param i   the row index
     * @param val the new target; must be finite
     * @throws ArithmeticException if {@code val} is NaN or infinite
     */
    public void setTargetValue(int i, double val) {
        if (Double.isInfinite(val) || Double.isNaN(val)) {
            throw new ArithmeticException("Can not predict a " + val + " value");
        }
        this.targets.set(i, val);
    }

    /**
     * Builds a new data set containing the rows at {@code indicies}, in that
     * order. Duplicate indices (e.g. bootstrap samples) produce duplicate rows.
     * Targets and per-point weights are carried over.
     *
     * @param indicies row indices into this data set
     * @return the subset
     * @throws IndexOutOfBoundsException if an index does not name a row
     */
    @Override
    protected RegressionDataSet getSubset(List<Integer> indicies) {
        if (this.datapoints.rowMajor()) {
            RegressionDataSet newData = new RegressionDataSet(this.numNumerVals, this.categories);
            for (int i : indicies) {
                newData.addDataPoint(this.getDataPoint(i), this.getTargetValue(i), this.getWeight(i));
            }
            return newData;
        }
        // Column-major stores are walked once, sequentially, via the row
        // iterator. Every requested row is snapshotted (copied, in case the
        // iterator reuses its objects) so that rows can then be emitted in the
        // requested order, including duplicates.
        HashMap<Integer, DataPoint> snapshots = new HashMap<Integer, DataPoint>(indicies.size());
        for (int idx : indicies) {
            snapshots.put(idx, null);
        }
        Iterator<DataPoint> dataIter = this.datapoints.getRowIter();
        int origPos = 0;
        while (dataIter.hasNext()) {
            DataPoint dp = dataIter.next();
            if (snapshots.containsKey(origPos)) {
                snapshots.put(origPos, new DataPoint(dp.getNumericalValues().clone(), Arrays.copyOf(dp.getCategoricalValues(), this.getNumCategoricalVars()), this.categories));
            }
            ++origPos;
        }
        DataStore newDs = this.datapoints.emptyClone();
        DoubleList newTargets = new DoubleList(indicies.size());
        for (int idx : indicies) {
            DataPoint snapshot = snapshots.get(idx);
            if (snapshot == null) {
                throw new IndexOutOfBoundsException("Index " + idx + " is not a row of a data set of size " + origPos);
            }
            newDs.addDataPoint(snapshot);
            newTargets.add(this.getTargetValue(idx));
        }
        newDs.finishAdding();
        RegressionDataSet subset = new RegressionDataSet(newDs, newTargets);
        for (int newI = 0; newI < indicies.size(); ++newI) {
            subset.setWeight(newI, this.getWeight(indicies.get(newI)));
        }
        return subset;
    }

    /**
     * Returns all targets as a new dense vector in row order.
     *
     * @return a fresh vector of length {@code size()}
     */
    public Vec getTargetValues() {
        DenseVector vals = new DenseVector(this.size());
        for (int i = 0; i < this.size(); ++i) {
            vals.set(i, this.targets.getD(i));
        }
        return vals;
    }

    /**
     * Returns the target of row {@code i}.
     *
     * @param i the row index
     * @return the target value
     */
    public double getTargetValue(int i) {
        return this.targets.getD(i);
    }

    /**
     * Static factory equivalent to {@link #RegressionDataSet(List)}.
     *
     * @param list the (point, target) pairs; must be non-empty
     * @return a new row-major regression data set
     */
    public static RegressionDataSet usingDPPList(List<DataPointPair<Double>> list) {
        return new RegressionDataSet(list);
    }

    /**
     * Returns a new data set that shares this set's data points but has its own
     * target list and weights (copied from this set).
     *
     * @return the shallow clone
     */
    public RegressionDataSet shallowClone() {
        RegressionDataSet clone = new RegressionDataSet(this.numNumerVals, this.categories);
        for (int i = 0; i < this.size(); ++i) {
            clone.addDataPoint(this.getDataPoint(i), this.getTargetValue(i), this.getWeight(i));
        }
        return clone;
    }

    /**
     * Returns an empty data set with the same feature layout.
     *
     * @return a new, empty regression data set
     */
    public RegressionDataSet emptyClone() {
        return new RegressionDataSet(this.numNumerVals, this.categories);
    }

    /** {@inheritDoc} Narrows the return type to {@code RegressionDataSet}. */
    @Override
    public RegressionDataSet getTwiceShallowClone() {
        return (RegressionDataSet)super.getTwiceShallowClone();
    }
}

