/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.datatransform.InPlaceTransform;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.OnLineStatistics;
import jsat.utils.DoubleList;
import jsat.utils.IndexTable;

/**
 * In-place transform that replaces missing values in a data point with values
 * learned from a data set.
 *
 * <p>A numeric value is treated as missing when it is {@code NaN}; a categorical
 * value is treated as missing when it is negative. Missing numeric values are
 * replaced by the per-column mean or weighted median (see
 * {@link NumericImputionMode}); missing categorical values are replaced by the
 * category that accumulated the largest total weight in that column.
 */
public class Imputer
implements InPlaceTransform {
    /** Strategy used to choose the replacement value of each numeric column. */
    private NumericImputionMode mode;
    /** Replacement category for each categorical column, set by {@link #fit(DataSet)}. */
    protected int[] cat_imputs;
    /** Replacement value for each numeric column, set by {@link #fit(DataSet)}. */
    protected double[] numeric_imputs;

    /**
     * Fits an imputer on the given data using {@link NumericImputionMode#MEAN}.
     *
     * @param data the data set to learn replacement values from
     */
    public Imputer(DataSet<?> data) {
        this(data, NumericImputionMode.MEAN);
    }

    /**
     * Fits an imputer on the given data using the given numeric strategy.
     *
     * @param data the data set to learn replacement values from
     * @param mode how the replacement for numeric columns is computed
     */
    public Imputer(DataSet<?> data, NumericImputionMode mode) {
        this.mode = mode;
        this.fit(data);
    }

    /**
     * Copy constructor. The learned replacement arrays are copied, so the new
     * instance does not share them with {@code toCopy}.
     *
     * @param toCopy the imputer to copy
     */
    public Imputer(Imputer toCopy) {
        this.mode = toCopy.mode;
        if (toCopy.cat_imputs != null) {
            this.cat_imputs = Arrays.copyOf(toCopy.cat_imputs, toCopy.cat_imputs.length);
        }
        if (toCopy.numeric_imputs != null) {
            this.numeric_imputs = Arrays.copyOf(toCopy.numeric_imputs, toCopy.numeric_imputs.length);
        }
    }

    /**
     * Learns the replacement value of every numeric and categorical column
     * from {@code d}, overwriting any previously learned values.
     *
     * <p>The parameter type is left raw so that the overriding signature stays
     * exactly as it was.
     *
     * @param d the data set to learn replacement values from
     */
    @Override
    public void fit(DataSet d) {
        final int numericCount = d.getNumNumericalVars();
        final int categoricalCount = d.getNumCategoricalVars();
        this.numeric_imputs = new double[numericCount];
        this.cat_imputs = new int[categoricalCount];

        // Only populated in MEDIAN mode: per column, the observed non-NaN values,
        // the weight of the sample each value came from, and the sum of those weights.
        List<DoubleList> columnValues = null;
        List<DoubleList> columnWeights = null;
        double[] columnWeightSums = null;

        switch (this.mode) {
            case MEAN: {
                // The data set computes the column means itself (argument passed as in the original).
                OnLineStatistics[] stats = d.getOnlineColumnStats(true);
                for (int i = 0; i < stats.length; ++i) {
                    this.numeric_imputs[i] = stats[i].getMean();
                }
                break;
            }
            case MEDIAN: {
                columnValues = new ArrayList<DoubleList>(numericCount);
                columnWeights = new ArrayList<DoubleList>(numericCount);
                columnWeightSums = new double[numericCount];
                for (int i = 0; i < numericCount; ++i) {
                    columnValues.add(new DoubleList(d.size()));
                    columnWeights.add(new DoubleList(d.size()));
                }
                break;
            }
        }

        // Accumulated weight of every category, per categorical column.
        double[][] categoryWeights = new double[categoricalCount][];
        for (int i = 0; i < categoryWeights.length; ++i) {
            categoryWeights[i] = new double[d.getCategories()[i].getNumOfCategories()];
        }

        for (int sample = 0; sample < d.size(); ++sample) {
            DataPoint dp = d.getDataPoint(sample);
            double weight = d.getWeight(sample);

            int[] cats = dp.getCategoricalValues();
            for (int i = 0; i < cats.length; ++i) {
                if (cats[i] >= 0) { // negative means missing, so it casts no vote
                    categoryWeights[i][cats[i]] += weight;
                }
            }

            if (this.mode == NumericImputionMode.MEDIAN) {
                // NOTE(review): only entries yielded by iterating the Vec contribute here.
                // If that iteration skips implicit zeros of sparse vectors, those zeros are
                // left out of the median - confirm this is intended.
                for (IndexValue iv : dp.getNumericalValues()) {
                    if (Double.isNaN(iv.getValue())) {
                        continue;
                    }
                    int col = iv.getIndex();
                    columnValues.get(col).add(iv.getValue());
                    columnWeights.get(col).add(weight);
                    columnWeightSums[col] += weight;
                }
            }
        }

        if (this.mode == NumericImputionMode.MEDIAN) {
            IndexTable order = new IndexTable(numericCount);
            for (int col = 0; col < numericCount; ++col) {
                this.numeric_imputs[col] = weightedMedian(
                        columnValues.get(col), columnWeights.get(col), columnWeightSums[col], order);
            }
        }

        for (int col = 0; col < categoryWeights.length; ++col) {
            this.cat_imputs[col] = indexOfLargest(categoryWeights[col]);
        }
    }

    /**
     * Returns the weighted median of {@code values}: walking the values in sorted
     * order, the value at which the running weight first reaches half of
     * {@code totalWeight}. Returns {@code 0.0} when no value is visited.
     *
     * @param values the observed values of one column
     * @param weights the weight belonging to each entry of {@code values}
     * @param totalWeight the sum of {@code weights}
     * @param order scratch index table; it is reset and re-sorted by this call
     * @return the weighted median, or {@code 0.0} if no value is visited
     */
    private static double weightedMedian(List<Double> values, List<Double> weights, double totalWeight, IndexTable order) {
        order.reset();
        order.sort(values);
        double goal = totalWeight / 2.0;
        double median = 0.0;
        double cumulative = 0.0;
        for (int rank = 0; rank < order.length() && cumulative < goal; ++rank) {
            int idx = order.index(rank);
            median = values.get(idx);
            cumulative += weights.get(idx);
        }
        return median;
    }

    /**
     * Returns the index of the largest entry of {@code counts}. Ties keep the
     * lowest index, and an empty array yields {@code 0}.
     *
     * @param counts the values to search
     * @return the index of the first maximal entry, or {@code 0} if empty
     */
    private static int indexOfLargest(double[] counts) {
        int best = 0;
        for (int j = 1; j < counts.length; ++j) {
            if (counts[j] > counts[best]) {
                best = j;
            }
        }
        return best;
    }

    /**
     * Replaces, in place, every {@code NaN} numeric value and every negative
     * categorical value of {@code dp} with the learned replacement of its column.
     *
     * @param dp the data point to modify
     */
    @Override
    public void mutableTransform(DataPoint dp) {
        Vec vec = dp.getNumericalValues();

        // Collect the missing positions first and write afterwards, so the vector is
        // never modified while its iterator is live. (Precaution: a Vec implementation
        // might restructure its storage on set(), e.g. when the new value is zero.)
        List<Integer> missing = null;
        for (IndexValue iv : vec) {
            if (Double.isNaN(iv.getValue())) {
                if (missing == null) {
                    missing = new ArrayList<Integer>();
                }
                missing.add(iv.getIndex());
            }
        }
        if (missing != null) {
            for (int index : missing) {
                vec.set(index, this.numeric_imputs[index]);
            }
        }

        int[] cats = dp.getCategoricalValues();
        for (int i = 0; i < cats.length; ++i) {
            if (cats[i] < 0) {
                cats[i] = this.cat_imputs[i];
            }
        }
    }

    /**
     * @return {@code true}, since {@link #mutableTransform(DataPoint)} writes to
     *         the categorical values of the data point
     */
    @Override
    public boolean mutatesNominal() {
        return true;
    }

    /**
     * Returns an imputed copy of {@code dp}; the argument itself is not passed
     * to {@link #mutableTransform(DataPoint)}.
     *
     * @param dp the data point to transform
     * @return a clone of {@code dp} with missing values replaced
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        DataPoint toRet = dp.clone();
        this.mutableTransform(toRet);
        return toRet;
    }

    /**
     * @return a copy of this imputer created through the copy constructor
     */
    @Override
    public Imputer clone() {
        return new Imputer(this);
    }

    /** How the replacement value of a numeric column is computed. */
    public static enum NumericImputionMode {
        /** Use the column mean reported by the data set's column statistics. */
        MEAN,
        /** Use the weighted median of the column's non-{@code NaN} values. */
        MEDIAN;

    }
}

