/*
 * Decompiled with CFR 0.152.
 */
package jsat.parameters;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.Classifier;
import jsat.classifiers.WarmClassifier;
import jsat.distributions.Distribution;
import jsat.parameters.DoubleParameter;
import jsat.parameters.IntParameter;
import jsat.parameters.ModelSearch;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.RegressionModelEvaluation;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.utils.DoubleList;
import jsat.utils.concurrent.ParallelUtils;

public class GridSearch
extends ModelSearch {
    private static final long serialVersionUID = -1987196172499143753L;
    private List<List<Double>> searchValues;
    private boolean useWarmStarts = true;

    public GridSearch(Regressor baseRegressor, int folds) {
        super(baseRegressor, folds);
        this.searchValues = new ArrayList<List<Double>>();
    }

    public GridSearch(Classifier baseClassifier, int folds) {
        super(baseClassifier, folds);
        this.searchValues = new ArrayList<List<Double>>();
    }

    public GridSearch(GridSearch toCopy) {
        super(toCopy);
        this.useWarmStarts = toCopy.useWarmStarts;
        if (toCopy.searchValues != null) {
            this.searchValues = new ArrayList<List<Double>>();
            for (List<Double> ld : toCopy.searchValues) {
                DoubleList newVals = new DoubleList(ld);
                this.searchValues.add(newVals);
            }
        }
    }

    public int autoAddParameters(DataSet data) {
        return this.autoAddParameters(data, 10);
    }

    public int autoAddParameters(DataSet data, int paramsEach) {
        Parameterized obj = this.baseClassifier != null ? (Parameterized)((Object)this.baseClassifier) : (Parameterized)((Object)this.baseRegressor);
        int totalParms = 0;
        for (Parameter param : obj.getParameters()) {
            Distribution dist;
            if (param instanceof DoubleParameter) {
                dist = ((DoubleParameter)param).getGuess(data);
                if (dist == null) continue;
                ++totalParms;
                continue;
            }
            if (!(param instanceof IntParameter) || (dist = ((IntParameter)param).getGuess(data)) == null) continue;
            ++totalParms;
        }
        if (totalParms < 1) {
            return 0;
        }
        double[] quantiles = new double[paramsEach];
        for (int i = 0; i < quantiles.length; ++i) {
            quantiles[i] = ((double)i + 1.0) / ((double)paramsEach + 1.0);
        }
        for (Parameter param : obj.getParameters()) {
            int i;
            Object[] vals;
            Distribution dist;
            if (param instanceof DoubleParameter) {
                dist = ((DoubleParameter)param).getGuess(data);
                if (dist == null) continue;
                vals = new double[paramsEach];
                for (i = 0; i < vals.length; ++i) {
                    vals[i] = dist.invCdf(quantiles[i]);
                }
                this.addParameter((DoubleParameter)param, (double[])vals);
                continue;
            }
            if (!(param instanceof IntParameter) || (dist = ((IntParameter)param).getGuess(data)) == null) continue;
            vals = new int[paramsEach];
            for (i = 0; i < vals.length; ++i) {
                vals[i] = (int)Math.round(dist.invCdf(quantiles[i]));
            }
            this.addParameter((IntParameter)param, (int[])vals);
        }
        return totalParms;
    }

    public void setUseWarmStarts(boolean useWarmStarts) {
        this.useWarmStarts = useWarmStarts;
    }

    public boolean isUseWarmStarts() {
        return this.useWarmStarts;
    }

    public void addParameter(DoubleParameter param, double ... initialSearchValues) {
        if (param == null) {
            throw new IllegalArgumentException("null not allowed for parameter");
        }
        this.searchParams.add(param);
        DoubleList dl = new DoubleList(initialSearchValues.length);
        for (double d : initialSearchValues) {
            dl.add(d);
        }
        Arrays.sort(dl.getBackingArray());
        if (param.isWarmParameter() && !param.preferredLowToHigh()) {
            Collections.reverse(dl);
        }
        if (param.isWarmParameter()) {
            this.searchValues.add(0, dl);
        } else {
            this.searchValues.add(dl);
        }
    }

    public void addParameter(String name, double ... initialSearchValues) {
        Parameter param = this.getParameterByName(name);
        if (!(param instanceof DoubleParameter)) {
            throw new IllegalArgumentException("Parameter " + name + " is not for double values");
        }
        this.addParameter((DoubleParameter)param, initialSearchValues);
    }

    public void addParameter(IntParameter param, int ... initialSearchValues) {
        this.searchParams.add(param);
        DoubleList dl = new DoubleList(initialSearchValues.length);
        int[] nArray = initialSearchValues;
        int n = nArray.length;
        for (int i = 0; i < n; ++i) {
            double d = nArray[i];
            dl.add(d);
        }
        Arrays.sort(dl.getBackingArray());
        if (param.isWarmParameter() && !param.preferredLowToHigh()) {
            Collections.reverse(dl);
        }
        if (param.isWarmParameter()) {
            this.searchValues.add(0, dl);
        } else {
            this.searchValues.add(dl);
        }
    }

    public void addParameter(String name, int ... initialSearchValues) {
        Parameter param = this.getParameterByName(name);
        if (!(param instanceof IntParameter)) {
            throw new IllegalArgumentException("Parameter " + name + " is not for int values");
        }
        this.addParameter((IntParameter)param, initialSearchValues);
    }

    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        boolean considerWarm;
        ArrayList<RegressionDataSet> trainCombinations;
        List<RegressionDataSet> preFolded;
        PriorityQueue bestModels = new PriorityQueue(this.folds, (t, t1) -> {
            double v0 = t.getScoreStats(this.regressionTargetScore).getMean();
            double v1 = t1.getScoreStats(this.regressionTargetScore).getMean();
            int order = this.regressionTargetScore.lowerIsBetter() ? 1 : -1;
            return order * Double.compare(v0, v1);
        });
        int[] setTo = new int[this.searchParams.size()];
        ArrayList<Regressor> paramsToEval = new ArrayList<Regressor>();
        do {
            this.setParameters(setTo);
            paramsToEval.add(this.baseRegressor.clone());
        } while (!this.incrementCombination(setTo));
        if (this.reuseSameCVFolds) {
            preFolded = dataSet.cvSet(this.folds);
            trainCombinations = new ArrayList<RegressionDataSet>(preFolded.size());
            for (int i = 0; i < preFolded.size(); ++i) {
                trainCombinations.add(RegressionDataSet.comineAllBut(preFolded, i));
            }
        } else {
            preFolded = null;
            trainCombinations = null;
        }
        boolean bl = considerWarm = this.useWarmStarts && this.baseRegressor instanceof WarmRegressor;
        if (considerWarm && (!((WarmRegressor)this.baseRegressor).warmFromSameDataOnly() || this.reuseSameCVFolds)) {
            ParallelUtils.run(parallel && this.trainModelsInParallel, paramsToEval.size(), (start, end) -> {
                List subSet = paramsToEval.subList(start, end);
                Regressor[] prevModels = null;
                for (Regressor r : subSet) {
                    RegressionModelEvaluation rme = new RegressionModelEvaluation(r, dataSet, !this.trainModelsInParallel && parallel);
                    rme.setKeepModels(true);
                    rme.setWarmModels(prevModels);
                    rme.addScorer(this.regressionTargetScore.clone());
                    if (this.reuseSameCVFolds) {
                        rme.evaluateCrossValidation(preFolded, trainCombinations);
                    } else {
                        rme.evaluateCrossValidation(this.folds);
                    }
                    prevModels = rme.getKeptModels();
                    PriorityQueue priorityQueue = bestModels;
                    synchronized (priorityQueue) {
                        bestModels.add(rme);
                    }
                }
            });
        } else {
            ParallelUtils.run(parallel && this.trainModelsInParallel, paramsToEval.size(), indx -> {
                Regressor r = (Regressor)paramsToEval.get(indx);
                RegressionModelEvaluation rme = new RegressionModelEvaluation(r, dataSet, !this.trainModelsInParallel && parallel);
                rme.addScorer(this.regressionTargetScore.clone());
                if (this.reuseSameCVFolds) {
                    rme.evaluateCrossValidation(preFolded, trainCombinations);
                } else {
                    rme.evaluateCrossValidation(this.folds);
                }
                PriorityQueue priorityQueue = bestModels;
                synchronized (priorityQueue) {
                    bestModels.add(rme);
                }
            });
        }
        Regressor bestRegressor = ((RegressionModelEvaluation)bestModels.peek()).getRegressor();
        if (this.trainFinalModel) {
            if (this.useWarmStarts && bestRegressor instanceof WarmRegressor && !((WarmRegressor)bestRegressor).warmFromSameDataOnly()) {
                WarmRegressor wr = (WarmRegressor)bestRegressor;
                wr.train(dataSet, (Regressor)wr.clone(), parallel);
            } else {
                bestRegressor.train(dataSet, parallel);
            }
        }
        this.trainedRegressor = bestRegressor;
    }

    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        boolean considerWarm;
        ArrayList<ClassificationDataSet> trainCombinations;
        List<ClassificationDataSet> preFolded;
        PriorityQueue bestModels = new PriorityQueue(this.folds, (t, t1) -> {
            double v0 = t.getScoreStats(this.classificationTargetScore).getMean();
            double v1 = t1.getScoreStats(this.classificationTargetScore).getMean();
            int order = this.classificationTargetScore.lowerIsBetter() ? 1 : -1;
            return order * Double.compare(v0, v1);
        });
        int[] setTo = new int[this.searchParams.size()];
        ArrayList<Classifier> paramsToEval = new ArrayList<Classifier>();
        do {
            this.setParameters(setTo);
            paramsToEval.add(this.baseClassifier.clone());
        } while (!this.incrementCombination(setTo));
        if (this.reuseSameCVFolds) {
            preFolded = dataSet.cvSet(this.folds);
            trainCombinations = new ArrayList<ClassificationDataSet>(preFolded.size());
            for (int i = 0; i < preFolded.size(); ++i) {
                trainCombinations.add(ClassificationDataSet.comineAllBut(preFolded, i));
            }
        } else {
            preFolded = null;
            trainCombinations = null;
        }
        boolean bl = considerWarm = this.useWarmStarts && this.baseClassifier instanceof WarmClassifier;
        if (considerWarm && (!((WarmClassifier)this.baseClassifier).warmFromSameDataOnly() || this.reuseSameCVFolds)) {
            ParallelUtils.run(parallel && this.trainModelsInParallel, paramsToEval.size(), (start, end) -> {
                List subSet = paramsToEval.subList(start, end);
                Classifier[] prevModels = null;
                for (Classifier r : subSet) {
                    ClassificationModelEvaluation cme = new ClassificationModelEvaluation(r, dataSet, !this.trainModelsInParallel && parallel);
                    cme.setKeepModels(true);
                    cme.setWarmModels(prevModels);
                    cme.addScorer(this.classificationTargetScore.clone());
                    if (this.reuseSameCVFolds) {
                        cme.evaluateCrossValidation(preFolded, trainCombinations);
                    } else {
                        cme.evaluateCrossValidation(this.folds);
                    }
                    prevModels = cme.getKeptModels();
                    PriorityQueue priorityQueue = bestModels;
                    synchronized (priorityQueue) {
                        bestModels.add(cme);
                    }
                }
            });
        } else {
            ParallelUtils.run(parallel && this.trainModelsInParallel, paramsToEval.size(), indx -> {
                Classifier toTrain = (Classifier)paramsToEval.get(indx);
                ClassificationModelEvaluation cme = new ClassificationModelEvaluation(toTrain, dataSet, !this.trainModelsInParallel && parallel);
                cme.addScorer(this.classificationTargetScore.clone());
                if (this.reuseSameCVFolds) {
                    cme.evaluateCrossValidation(preFolded, trainCombinations);
                } else {
                    cme.evaluateCrossValidation(this.folds);
                }
                PriorityQueue priorityQueue = bestModels;
                synchronized (priorityQueue) {
                    bestModels.add(cme);
                }
            });
        }
        Classifier bestClassifier = ((ClassificationModelEvaluation)bestModels.peek()).getClassifier();
        if (this.trainFinalModel) {
            if (this.useWarmStarts && bestClassifier instanceof WarmClassifier && !((WarmClassifier)bestClassifier).warmFromSameDataOnly()) {
                WarmClassifier wc = (WarmClassifier)bestClassifier;
                wc.train(dataSet, (Classifier)wc.clone(), parallel);
            } else {
                bestClassifier.train(dataSet, parallel);
            }
        }
        this.trainedClassifier = bestClassifier;
    }

    @Override
    public GridSearch clone() {
        return new GridSearch(this);
    }

    private boolean incrementCombination(int[] setTo) {
        setTo[0] = setTo[0] + 1;
        int carryPos = 0;
        while (carryPos < setTo.length - 1 && setTo[carryPos] >= this.searchValues.get(carryPos).size()) {
            setTo[carryPos] = 0;
            int n = ++carryPos;
            setTo[n] = setTo[n] + 1;
        }
        return setTo[setTo.length - 1] >= this.searchValues.get(setTo.length - 1).size();
    }

    private void setParameters(int[] setTo) {
        for (int i = 0; i < setTo.length; ++i) {
            Parameter param = (Parameter)this.searchParams.get(i);
            if (param instanceof DoubleParameter) {
                ((DoubleParameter)param).setValue(this.searchValues.get(i).get(setTo[i]));
                continue;
            }
            if (!(param instanceof IntParameter)) continue;
            ((IntParameter)param).setValue(this.searchValues.get(i).get(setTo[i]).intValue());
        }
    }
}

