/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform.featureselection;

import java.util.Iterator;
import java.util.Random;
import java.util.Set;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.RemoveAttributeTransform;
import jsat.datatransform.featureselection.SBS;
import jsat.datatransform.featureselection.SFS;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.random.RandomUtil;

/**
 * Feature selection that alternates greedy forward and backward passes
 * ("plus-L, minus-R").
 * <p>
 * When {@code L > R}, fitting starts with no features selected and runs
 * {@code L} steps of sequential forward selection ({@link SFS}). It then runs
 * {@code R} steps of sequential backward selection ({@link SBS}) over the
 * features just chosen.
 * <p>
 * When {@code L < R}, fitting starts with every feature selected and runs
 * {@code R} backward steps. It then runs {@code L} forward steps over the
 * features that were removed.
 * <p>
 * The evaluator ({@link Classifier} or {@link Regressor}) and the number of
 * cross-validation folds are passed through to the SFS/SBS helpers, which do
 * the actual scoring.
 * <p>
 * NOTE: if {@code L == R} neither pass runs, so the fitted transform keeps
 * every feature.
 */
public class LRS
implements DataTransform {
    private static final long serialVersionUID = 3065300352046535656L;
    /** Transform that drops the features not selected; null until fitted. */
    private RemoveAttributeTransform finalTransform;
    /** Indices of the selected categorical features; null until fitted. */
    private Set<Integer> catSelected;
    /** Indices of the selected numeric features; null until fitted. */
    private Set<Integer> numSelected;
    /** Number of features to greedily add. */
    private int L;
    /** Number of features to greedily remove. */
    private int R;
    /** Either a {@link Classifier} or a {@link Regressor}, depending on the constructor used. */
    private Object evaluater;
    /** Number of cross-validation folds passed to the SFS/SBS helpers. */
    private int folds;

    /**
     * Copy constructor used by {@link #clone()}.
     * <p>
     * The evaluator reference is shared. Any fitted state is copied.
     *
     * @param toClone the instance to copy
     */
    private LRS(LRS toClone) {
        this.L = toClone.L;
        this.R = toClone.R;
        this.folds = toClone.folds;
        this.evaluater = toClone.evaluater;
        // Copy each piece of fitted state on its own. search() assigns the
        // selection sets before the transform, so a failed fit can leave them
        // out of sync; cloning such an instance must not throw.
        if (toClone.finalTransform != null) {
            this.finalTransform = toClone.finalTransform.clone();
        }
        if (toClone.catSelected != null) {
            this.catSelected = new IntSet(toClone.catSelected);
        }
        if (toClone.numSelected != null) {
            this.numSelected = new IntSet(toClone.numSelected);
        }
    }

    /**
     * Creates an unfitted LRS selector for a classification problem.
     *
     * @param L the number of features to greedily add (must be at least 1)
     * @param R the number of features to greedily remove (must be at least 1)
     * @param evaluater the classifier used to score candidate feature subsets
     * @param folds the number of cross-validation folds (must be positive)
     * @throws IllegalArgumentException if {@code L}, {@code R} or {@code folds} is out of range
     */
    public LRS(int L, int R, Classifier evaluater, int folds) {
        this.setFeaturesToAdd(L);
        this.setFeaturesToRemove(R);
        this.setFolds(folds);
        this.setEvaluator(evaluater);
    }

    /**
     * Creates an LRS selector for a classification problem and immediately
     * fits it to {@code cds}.
     *
     * @param L the number of features to greedily add (must be at least 1)
     * @param R the number of features to greedily remove (must be at least 1)
     * @param cds the data set to select features from
     * @param evaluater the classifier used to score candidate feature subsets
     * @param folds the number of cross-validation folds (must be positive)
     * @throws IllegalArgumentException if {@code L}, {@code R} or {@code folds} is out of range
     */
    public LRS(int L, int R, ClassificationDataSet cds, Classifier evaluater, int folds) {
        // Delegate so the parameters are validated and stored. Without this, a
        // later fit(...) would run with L = R = 0 and a null evaluator.
        this(L, R, evaluater, folds);
        this.search(cds, L, R, evaluater, folds);
    }

    /**
     * Creates an unfitted LRS selector for a regression problem.
     *
     * @param L the number of features to greedily add (must be at least 1)
     * @param R the number of features to greedily remove (must be at least 1)
     * @param evaluater the regressor used to score candidate feature subsets
     * @param folds the number of cross-validation folds (must be positive)
     * @throws IllegalArgumentException if {@code L}, {@code R} or {@code folds} is out of range
     */
    public LRS(int L, int R, Regressor evaluater, int folds) {
        this.setFeaturesToAdd(L);
        this.setFeaturesToRemove(R);
        this.setFolds(folds);
        this.setEvaluator(evaluater);
    }

    /**
     * Creates an LRS selector for a regression problem and immediately fits it
     * to {@code rds}.
     *
     * @param L the number of features to greedily add (must be at least 1)
     * @param R the number of features to greedily remove (must be at least 1)
     * @param rds the data set to select features from
     * @param evaluater the regressor used to score candidate feature subsets
     * @param folds the number of cross-validation folds (must be positive)
     * @throws IllegalArgumentException if {@code L}, {@code R} or {@code folds} is out of range
     */
    public LRS(int L, int R, RegressionDataSet rds, Regressor evaluater, int folds) {
        this(L, R, evaluater, folds);
        this.search(rds, L, R, evaluater, folds);
    }

    /**
     * Removes the unselected features from {@code dp}.
     * <p>
     * Requires that this selector has been fitted.
     *
     * @param dp the data point to transform
     * @return the transformed data point
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        return this.finalTransform.transform(dp);
    }

    @Override
    public LRS clone() {
        return new LRS(this);
    }

    /**
     * Returns a copy of the selected categorical feature indices.
     * <p>
     * Requires that this selector has been fitted.
     *
     * @return a new set with the selected categorical feature indices
     */
    public Set<Integer> getSelectedCategorical() {
        return new IntSet(this.catSelected);
    }

    /**
     * Returns a copy of the selected numeric feature indices.
     * <p>
     * Requires that this selector has been fitted.
     *
     * @return a new set with the selected numeric feature indices
     */
    public Set<Integer> getSelectedNumerical() {
        return new IntSet(this.numSelected);
    }

    /**
     * Runs the selection on {@code data} using the stored L, R, evaluator and
     * fold count.
     *
     * @param data the data set to select features from
     */
    @Override
    public void fit(DataSet data) {
        this.search(data, this.L, this.R, this.evaluater, this.folds);
    }

    /**
     * Performs the plus-L / minus-R search and builds {@link #finalTransform}.
     * <p>
     * Candidate features are tracked in a combined index space:
     * <ul>
     *   <li>categorical feature {@code c} is index {@code c};</li>
     *   <li>numeric feature {@code n} is index {@code n + nCat}.</li>
     * </ul>
     * If {@code L == R}, neither pass runs and no feature is removed.
     *
     * @param cds the data set to select features from
     * @param L the number of forward steps
     * @param R the number of backward steps
     * @param evaluater the Classifier or Regressor passed to the SFS/SBS helpers
     * @param folds the number of cross-validation folds passed to the helpers
     */
    private void search(DataSet cds, int L, int R, Object evaluater, int folds) {
        final int nF = cds.getNumFeatures();
        final int nCat = cds.getNumCategoricalVars();
        final int nNum = nF - nCat;
        this.catSelected = new IntSet(nCat);
        this.numSelected = new IntSet(nNum);
        IntSet catToRemove = new IntSet(nCat);
        IntSet numToRemove = new IntSet(nNum);
        // Candidate pool for the current pass, in the combined index space.
        IntSet available = new IntSet(nF);
        ListUtils.addRange(available, 0, nF, 1);
        Random rand = RandomUtil.getRandom();
        // One-element holder shared by the SFS/SBS helpers; presumably they
        // update the best score in place across steps.
        double[] pBestScore = {Double.POSITIVE_INFINITY};
        if (L > R) {
            // Start with nothing selected: every feature is marked for removal.
            ListUtils.addRange(catToRemove, 0, nCat, 1);
            ListUtils.addRange(numToRemove, 0, nNum, 1);
            for (int step = 0; step < L; ++step) {
                SFS.SFSSelectFeature(available, cds, catToRemove, numToRemove,
                        this.catSelected, this.numSelected, evaluater, folds,
                        rand, pBestScore, L);
            }
            // The backward pass may only drop features chosen by the forward pass.
            resetToCombinedIndices(available, this.catSelected, this.numSelected, nCat);
            for (int step = 0; step < R; ++step) {
                SBS.SBSRemoveFeature(available, cds, catToRemove, numToRemove,
                        this.catSelected, this.numSelected, evaluater, folds,
                        rand, L - R, pBestScore, 0.0);
            }
        } else if (L < R) {
            // Start with everything selected.
            ListUtils.addRange(this.catSelected, 0, nCat, 1);
            ListUtils.addRange(this.numSelected, 0, nNum, 1);
            for (int step = 0; step < R; ++step) {
                SBS.SBSRemoveFeature(available, cds, catToRemove, numToRemove,
                        this.catSelected, this.numSelected, evaluater, folds,
                        rand, nF - R, pBestScore, 0.0);
            }
            // The forward pass may only re-add features dropped by the backward pass.
            resetToCombinedIndices(available, catToRemove, numToRemove, nCat);
            for (int step = 0; step < L; ++step) {
                SFS.SFSSelectFeature(available, cds, catToRemove, numToRemove,
                        this.catSelected, this.numSelected, evaluater, folds,
                        rand, pBestScore, R - L);
            }
        }
        this.finalTransform = new RemoveAttributeTransform(cds, catToRemove, numToRemove);
    }

    /**
     * Rebuilds {@code target} in the combined index space.
     * <p>
     * {@code target} is cleared, then refilled with the categorical indices
     * unchanged and the numeric indices shifted by {@code nCat}.
     *
     * @param target the set to rebuild
     * @param cat categorical feature indices
     * @param num numeric feature indices
     * @param nCat the number of categorical features (offset for numeric indices)
     */
    private static void resetToCombinedIndices(Set<Integer> target, Set<Integer> cat,
                                               Set<Integer> num, int nCat) {
        target.clear();
        target.addAll(cat);
        for (int n : num) {
            target.add(n + nCat);
        }
    }

    /**
     * Sets the number of features to greedily add.
     *
     * @param featuresToAdd the number of forward steps; must be at least 1
     * @throws IllegalArgumentException if {@code featuresToAdd < 1}
     */
    public void setFeaturesToAdd(int featuresToAdd) {
        if (featuresToAdd < 1) {
            throw new IllegalArgumentException("Number of features to add must be positive, not " + featuresToAdd);
        }
        this.L = featuresToAdd;
    }

    /** @return the number of features to greedily add */
    public int getFeaturesToAdd() {
        return this.L;
    }

    /**
     * Sets the number of features to greedily remove.
     *
     * @param featuresToRemove the number of backward steps; must be at least 1
     * @throws IllegalArgumentException if {@code featuresToRemove < 1}
     */
    public void setFeaturesToRemove(int featuresToRemove) {
        if (featuresToRemove < 1) {
            throw new IllegalArgumentException("Number of features to remove must be positive, not " + featuresToRemove);
        }
        this.R = featuresToRemove;
    }

    /** @return the number of features to greedily remove */
    public int getFeaturesToRemove() {
        return this.R;
    }

    /**
     * Sets the number of cross-validation folds used when scoring subsets.
     *
     * @param folds the number of folds; must be positive
     * @throws IllegalArgumentException if {@code folds <= 0}
     */
    public void setFolds(int folds) {
        if (folds <= 0) {
            throw new IllegalArgumentException("Number of CV folds must be positive, not " + folds);
        }
        this.folds = folds;
    }

    /** @return the number of cross-validation folds */
    public int getFolds() {
        return this.folds;
    }

    /** Stores the evaluator (a Classifier or Regressor) used to score subsets. */
    private void setEvaluator(Object evaluator) {
        this.evaluater = evaluator;
    }
}

