/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.calibration;

import java.util.Collections;
import java.util.List;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;

/**
 * Base class for calibrating the raw scores produced by a
 * {@link BinaryScoreClassifier}.
 * <p>
 * {@link #train(ClassificationDataSet, boolean)} trains the wrapped classifier
 * and collects (score, label) pairs according to the selected
 * {@link CalibrationMode}. It then hands them to
 * {@link #calibrate(boolean[], double[], int)}, which subclasses implement to
 * fit their calibration. In every mode the wrapped classifier ends up trained
 * on the full data set.
 */
public abstract class BinaryCalibration
implements Classifier,
Parameterized {
    private static final long serialVersionUID = 2356311701854978890L;
    /** The wrapped classifier whose scores are calibrated. */
    @Parameter.ParameterHolder
    protected BinaryScoreClassifier base;
    /** Number of folds used when {@link #mode} is {@link CalibrationMode#CV}. */
    protected int folds = 3;
    /** Fraction of the data held out for scoring when {@link #mode} is {@link CalibrationMode#HOLD_OUT}. */
    protected double holdOut = 0.3;
    /** Strategy used to collect the scores passed to {@link #calibrate(boolean[], double[], int)}. */
    protected CalibrationMode mode;

    /**
     * Creates a calibrator around the given classifier.
     *
     * @param base the classifier whose scores will be calibrated
     * @param mode how calibration scores are collected during training
     */
    public BinaryCalibration(BinaryScoreClassifier base, CalibrationMode mode) {
        this.base = base;
        // Assign directly instead of calling the overridable setter: an
        // override would otherwise run before the subclass is fully constructed.
        this.mode = mode;
    }

    /**
     * Trains the wrapped classifier, collects calibration scores according to
     * the current {@link CalibrationMode}, and fits the calibration.
     * A mode other than {@code CV} or {@code HOLD_OUT} (including
     * {@code null}) uses the naive strategy.
     *
     * @param dataSet  the training data; category index 1 is treated as the positive class
     * @param parallel forwarded to every training call on the wrapped classifier
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        double[] deci = new double[dataSet.size()];
        boolean[] label = new boolean[deci.length];
        int len;
        if (this.mode == CalibrationMode.CV) {
            len = this.collectCrossValidationScores(dataSet, parallel, deci, label);
        } else if (this.mode == CalibrationMode.HOLD_OUT) {
            len = this.collectHoldOutScores(dataSet, parallel, deci, label);
        } else {
            len = this.collectNaiveScores(dataSet, parallel, deci, label);
        }
        this.calibrate(label, deci, len);
    }

    /**
     * Scores each fold with a model trained on the remaining folds, then
     * retrains the wrapped classifier on the full data set.
     *
     * @return the number of leading entries written into {@code deci} and {@code label}
     */
    private int collectCrossValidationScores(ClassificationDataSet dataSet, boolean parallel, double[] deci, boolean[] label) {
        List<ClassificationDataSet> foldList = dataSet.cvSet(this.folds);
        int pos = 0;
        for (int i = 0; i < foldList.size(); ++i) {
            ClassificationDataSet test = foldList.get(i);
            ClassificationDataSet train = ClassificationDataSet.comineAllBut(foldList, i);
            this.base.train(train, parallel);
            for (int j = 0; j < test.size(); ++j) {
                deci[pos] = this.base.getScore(test.getDataPoint(j));
                label[pos] = test.getDataPointCategory(j) == 1;
                ++pos;
            }
        }
        this.base.train(dataSet, parallel);
        // Report how many points were actually scored rather than assuming
        // the folds cover every slot of the arrays.
        return pos;
    }

    /**
     * Shuffles the data, trains on the first {@code 1 - holdOut} fraction and
     * scores the remainder, then retrains the wrapped classifier on the full
     * data set.
     *
     * @return the number of leading entries written into {@code deci} and {@code label}
     */
    private int collectHoldOutScores(ClassificationDataSet dataSet, boolean parallel, double[] deci, boolean[] label) {
        List<DataPointPair<Integer>> wholeSet = dataSet.getAsDPPList();
        Collections.shuffle(wholeSet);
        int splitMark = (int)((double)wholeSet.size() * (1.0 - this.holdOut));
        ClassificationDataSet train = new ClassificationDataSet(wholeSet.subList(0, splitMark), dataSet.getPredicting());
        ClassificationDataSet test = new ClassificationDataSet(wholeSet.subList(splitMark, wholeSet.size()), dataSet.getPredicting());
        this.base.train(train, parallel);
        for (int i = 0; i < test.size(); ++i) {
            deci[i] = this.base.getScore(test.getDataPoint(i));
            label[i] = test.getDataPointCategory(i) == 1;
        }
        this.base.train(dataSet, parallel);
        return test.size();
    }

    /**
     * Trains the wrapped classifier on the full data set and scores that same
     * data.
     *
     * @return the number of leading entries written into {@code deci} and {@code label}
     */
    private int collectNaiveScores(ClassificationDataSet dataSet, boolean parallel, double[] deci, boolean[] label) {
        this.base.train(dataSet, parallel);
        int count = dataSet.size();
        for (int i = 0; i < count; ++i) {
            DataPoint dp = dataSet.getDataPoint(i);
            deci[i] = this.base.getScore(dp);
            label[i] = dataSet.getDataPointCategory(i) == 1;
        }
        return count;
    }

    /**
     * Fits the calibration from the collected scores.
     *
     * @param label  {@code label[i]} is {@code true} when point {@code i} belongs to category 1
     * @param scores {@code scores[i]} is the wrapped classifier's score for point {@code i}
     * @param len    only the first {@code len} entries of each array hold collected data
     */
    protected abstract void calibrate(boolean[] label, double[] scores, int len);

    /**
     * Sets the number of folds used in {@link CalibrationMode#CV} mode.
     * NOTE(review): a value of 1 presumably leaves no training data for the
     * per-fold models; confirm against {@code cvSet}/{@code comineAllBut}.
     *
     * @param folds the number of folds, at least 1
     * @throws IllegalArgumentException if {@code folds < 1}
     */
    public void setCalibrationFolds(int folds) {
        if (folds < 1) {
            throw new IllegalArgumentException("Folds must be a positive value, not " + folds);
        }
        this.folds = folds;
    }

    /**
     * @return the number of folds used in {@link CalibrationMode#CV} mode
     */
    public int getCalibrationFolds() {
        return this.folds;
    }

    /**
     * Sets the fraction of data held out for scoring in
     * {@link CalibrationMode#HOLD_OUT} mode.
     *
     * @param holdOut a fraction strictly between 0 and 1
     * @throws IllegalArgumentException if {@code holdOut} is NaN or outside (0, 1)
     */
    public void setCalibrationHoldOut(double holdOut) {
        if (Double.isNaN(holdOut) || holdOut <= 0.0 || holdOut >= 1.0) {
            throw new IllegalArgumentException("HoldOut must be in (0, 1), not " + holdOut);
        }
        this.holdOut = holdOut;
    }

    /**
     * @return the fraction of data held out in {@link CalibrationMode#HOLD_OUT} mode
     */
    public double getCalibrationHoldOut() {
        return this.holdOut;
    }

    /**
     * Sets how calibration scores are collected during training.
     *
     * @param mode the collection strategy
     */
    public void setCalibrationMode(CalibrationMode mode) {
        this.mode = mode;
    }

    /**
     * @return the current score-collection strategy
     */
    public CalibrationMode getCalibrationMode() {
        return this.mode;
    }

    @Override
    public abstract BinaryCalibration clone();

    /**
     * Strategies for collecting the scores used to fit the calibration.
     */
    public static enum CalibrationMode {
        /** Train on all data and score that same data. */
        NAIVE,
        /** Score each fold with a model trained on the other folds. */
        CV,
        /** Score a randomly held-out fraction with a model trained on the rest. */
        HOLD_OUT;

    }
}

