/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Semaphore;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Bootstrap aggregating ("bagging") ensemble that acts either as a {@link Classifier} or as a
 * {@link Regressor}, depending on which constructor created it.
 *
 * <p>Training builds {@link #getRounds()} models. For every round a bootstrap sample of
 * {@code dataSet.size() + extraSamples} points is drawn with replacement from the training set, and a
 * fresh {@code clone()} of the base model is trained on it. Classification is a majority vote (each
 * model gives one vote to its most likely category, then the votes are normalized). Regression
 * returns the mean of the models' predictions.
 */
public class Bagging
implements Classifier,
Regressor,
Parameterized {
    private static final long serialVersionUID = -6566453570170428838L;
    /** Model cloned for every round in classification mode; null in regression mode. */
    private Classifier baseClassifier;
    /** Model cloned for every round in regression mode; null in classification mode. */
    private Regressor baseRegressor;
    /** Target variable of the last classification training set; used to size the vote vector. */
    private CategoricalData predicting;
    /** Added to the training-set size to obtain the size of each bootstrap sample. */
    private int extraSamples;
    /** Number of models in the ensemble; always positive (enforced by {@link #setRounds(int)}). */
    private int rounds;
    /** Whether several rounds may train at the same time when parallel training is requested. */
    private boolean simultaniousTraining;
    /** Source of randomness for bootstrap sampling. */
    private Random random;
    /**
     * Trained models in round order: {@link Classifier}s after classification training,
     * {@link Regressor}s after regression training; null before the first successful training.
     */
    private List<Object> learners;
    /** Number of rounds used by the shorter constructors. */
    public static final int DEFAULT_ROUNDS = 20;
    /** Extra bootstrap samples used by the single-argument constructors. */
    public static final int DEFAULT_EXTRA_SAMPLES = 0;
    /** Simultaneous-training setting used by the single-argument constructors. */
    public static final boolean DEFAULT_SIMULTANIOUS_TRAINING = true;

    /**
     * Creates a bagged classifier with {@link #DEFAULT_EXTRA_SAMPLES} extra samples, simultaneous
     * training enabled, {@link #DEFAULT_ROUNDS} rounds and a {@link Random} seeded with 1.
     *
     * @param baseClassifier the model cloned and trained for every round
     */
    public Bagging(Classifier baseClassifier) {
        this(baseClassifier, DEFAULT_EXTRA_SAMPLES, DEFAULT_SIMULTANIOUS_TRAINING);
    }

    /**
     * Creates a bagged classifier with {@link #DEFAULT_ROUNDS} rounds and a {@link Random} seeded with 1.
     *
     * @param baseClassifier       the model cloned and trained for every round
     * @param extraSamples         added to the training-set size to get the bootstrap sample size
     * @param simultaniousTraining whether rounds may train concurrently when parallel training is requested
     */
    public Bagging(Classifier baseClassifier, int extraSamples, boolean simultaniousTraining) {
        this(baseClassifier, extraSamples, simultaniousTraining, DEFAULT_ROUNDS, new Random(1L));
    }

    /**
     * Creates a bagged classifier.
     *
     * @param baseClassifier       the model cloned and trained for every round
     * @param extraSamples         added to the training-set size to get the bootstrap sample size
     * @param simultaniousTraining whether rounds may train concurrently when parallel training is requested
     * @param rounds               number of models to train; must be positive
     * @param random               source of randomness for bootstrap sampling
     * @throws ArithmeticException if {@code rounds} is not positive
     */
    public Bagging(Classifier baseClassifier, int extraSamples, boolean simultaniousTraining, int rounds, Random random) {
        this(extraSamples, simultaniousTraining, rounds, random);
        this.baseClassifier = baseClassifier;
    }

    /**
     * Creates a bagged regressor with {@link #DEFAULT_EXTRA_SAMPLES} extra samples, simultaneous
     * training enabled, {@link #DEFAULT_ROUNDS} rounds and a {@link Random} seeded with 1.
     *
     * @param baseRegressor the model cloned and trained for every round
     */
    public Bagging(Regressor baseRegressor) {
        this(baseRegressor, DEFAULT_EXTRA_SAMPLES, DEFAULT_SIMULTANIOUS_TRAINING);
    }

    /**
     * Creates a bagged regressor with {@link #DEFAULT_ROUNDS} rounds and a {@link Random} seeded with 1.
     *
     * @param baseRegressor        the model cloned and trained for every round
     * @param extraSamples         added to the training-set size to get the bootstrap sample size
     * @param simultaniousTraining whether rounds may train concurrently when parallel training is requested
     */
    public Bagging(Regressor baseRegressor, int extraSamples, boolean simultaniousTraining) {
        this(baseRegressor, extraSamples, simultaniousTraining, DEFAULT_ROUNDS, new Random(1L));
    }

    /**
     * Creates a bagged regressor.
     *
     * @param baseRegressor        the model cloned and trained for every round
     * @param extraSamples         added to the training-set size to get the bootstrap sample size
     * @param simultaniousTraining whether rounds may train concurrently when parallel training is requested
     * @param rounds               number of models to train; must be positive
     * @param random               source of randomness for bootstrap sampling
     * @throws ArithmeticException if {@code rounds} is not positive
     */
    public Bagging(Regressor baseRegressor, int extraSamples, boolean simultaniousTraining, int rounds, Random random) {
        this(extraSamples, simultaniousTraining, rounds, random);
        this.baseRegressor = baseRegressor;
    }

    /** Shared initialisation of the ensemble settings; the base model is set by the caller. */
    private Bagging(int extraSamples, boolean simultaniousTraining, int rounds, Random random) {
        this.setExtraSamples(extraSamples);
        this.setSimultaniousTraining(simultaniousTraining);
        this.setRounds(rounds);
        this.random = random;
    }

    /**
     * Sets how many points are added to each bootstrap sample beyond the training-set size.
     *
     * @param i the number of extra samples
     */
    public void setExtraSamples(int i) {
        this.extraSamples = i;
    }

    /**
     * Returns how many points are added to each bootstrap sample beyond the training-set size.
     *
     * @return the number of extra samples
     */
    public int getExtraSamples() {
        return this.extraSamples;
    }

    /**
     * Sets how many models the ensemble trains.
     *
     * @param rounds the number of bootstrap rounds; must be positive
     * @throws ArithmeticException if {@code rounds} is not positive
     */
    public void setRounds(int rounds) {
        if (rounds <= 0) {
            throw new ArithmeticException("Must train a positive number of learners");
        }
        this.rounds = rounds;
    }

    /**
     * Returns how many models the ensemble trains.
     *
     * @return the number of bootstrap rounds
     */
    public int getRounds() {
        return this.rounds;
    }

    /**
     * Controls whether, when parallel training is requested, several rounds train at the same time
     * (true) or rounds run one after another with the parallel flag forwarded to each model (false).
     *
     * @param simultaniousTraining whether rounds may train concurrently
     */
    public void setSimultaniousTraining(boolean simultaniousTraining) {
        this.simultaniousTraining = simultaniousTraining;
    }

    /**
     * Classifies a point by majority vote of the trained models.
     *
     * @param data the point to classify
     * @return the normalized vote counts per category
     * @throws RuntimeException if this instance was built for regression or has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.baseClassifier == null) {
            throw new RuntimeException("Bagging instance created for regression, not classification");
        }
        if (this.learners == null || this.learners.isEmpty()) {
            throw new RuntimeException("Classifier has not yet been trained");
        }
        CategoricalResults votes = new CategoricalResults(this.predicting.getNumOfCategories());
        for (Object learner : this.learners) {
            // One vote per model for its most likely category; its probabilities are otherwise ignored.
            votes.incProb(((Classifier) learner).classify(data).mostLikely(), 1.0);
        }
        votes.normalize();
        return votes;
    }

    /**
     * Trains {@link #getRounds()} clones of the base classifier, each on its own bootstrap sample.
     * The previously trained ensemble is replaced only if every round succeeds.
     *
     * @param dataSet  the training data
     * @param parallel whether multiple threads may be used
     * @throws RuntimeException if this instance was built for regression, if a model fails to train
     *                          while rounds run concurrently (the failure is the cause), or if the
     *                          calling thread is interrupted (its interrupt status is restored)
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        if (this.baseClassifier == null) {
            throw new RuntimeException("Bagging instance created for regression, not classification");
        }
        boolean concurrentRounds = this.simultaniousTraining && parallel;
        List<Object> trained = this.trainRounds(dataSet.size(), concurrentRounds, sampleCounts -> {
            // Runs on the calling thread: the counts buffer is reused by the next round.
            ClassificationDataSet sampleSet = Bagging.getSampledDataSet(dataSet, sampleCounts);
            Classifier learner = this.baseClassifier.clone();
            return () -> {
                if (concurrentRounds) {
                    // Rounds already run side by side, so the single-argument train method is used
                    // instead of forwarding the parallel flag.
                    learner.train(sampleSet);
                } else {
                    learner.train(sampleSet, parallel);
                }
                return learner;
            };
        });
        this.predicting = dataSet.getPredicting();
        this.learners = trained;
    }

    /**
     * Runs {@link #rounds} bagging rounds and returns the trained models in round order.
     *
     * <p>Bootstrap sampling and round preparation always happen on the calling thread in round
     * order, so the values drawn from {@link #random} do not depend on thread scheduling. When
     * {@code concurrent} is true, each prepared job runs on a worker pool. At most
     * {@link SystemInfo#LogicalCores} jobs may be submitted but unfinished at once (plus the one being
     * prepared), which bounds how many bootstrap copies of the data are alive simultaneously.
     *
     * @param dataSize     number of points in the training set
     * @param concurrent   whether rounds are trained concurrently on a worker pool
     * @param prepareRound given one round's bootstrap counts, builds that round's sample and untrained
     *                     model and returns a job that trains and returns the model; the counts
     *                     array is reused for the next round, so it must not be retained
     * @return the trained models, in round order
     * @throws RuntimeException if a concurrent job fails (first failure as cause, others suppressed)
     *                          or the calling thread is interrupted (interrupt status restored)
     */
    private List<Object> trainRounds(int dataSize, boolean concurrent, Function<int[], Supplier<Object>> prepareRound) {
        Object[] trained = new Object[this.rounds];
        int[] sampleCounts = new int[dataSize];
        if (!concurrent) {
            for (int i = 0; i < this.rounds; ++i) {
                Bagging.sampleWithReplacement(sampleCounts, dataSize + this.extraSamples, this.random);
                trained[i] = prepareRound.apply(sampleCounts).get();
            }
            return new ArrayList<>(Arrays.asList(trained));
        }

        Semaphore inFlight = new Semaphore(SystemInfo.LogicalCores);
        CountDownLatch finished = new CountDownLatch(this.rounds);
        List<Throwable> failures = Collections.synchronizedList(new ArrayList<>());
        ExecutorService threadPool = ParallelUtils.getNewExecutor(true);
        try {
            for (int i = 0; i < this.rounds; ++i) {
                Bagging.sampleWithReplacement(sampleCounts, dataSize + this.extraSamples, this.random);
                Supplier<Object> job = prepareRound.apply(sampleCounts);
                int round = i;
                inFlight.acquire();
                try {
                    threadPool.submit(() -> {
                        try {
                            trained[round] = job.get();
                        } catch (Throwable t) {
                            // Recorded here; otherwise it would vanish into the executor's Future.
                            failures.add(t);
                        } finally {
                            // Always free the slot and count the round, or the caller would wait forever.
                            inFlight.release();
                            finished.countDown();
                        }
                    });
                } catch (RuntimeException ex) {
                    // The job was not accepted and will never release its permit itself.
                    inFlight.release();
                    throw ex;
                }
            }
            // Also publishes the workers' writes to 'trained' (happens-before via countDown/await).
            finished.await();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Bagging was interrupted before all learners finished training", ex);
        } finally {
            threadPool.shutdownNow();
        }

        if (!failures.isEmpty()) {
            RuntimeException error = new RuntimeException("A bagged learner failed to train", failures.get(0));
            for (int i = 1; i < failures.size(); ++i) {
                error.addSuppressed(failures.get(i));
            }
            throw error;
        }
        return new ArrayList<>(Arrays.asList(trained));
    }

    /**
     * Builds a classification data set in which point {@code i} of {@code dataSet} appears
     * {@code sampledCounts[i]} times, each copy labelled with that point's category. Only the first
     * {@code sampledCounts.length} points are considered; non-positive counts contribute nothing.
     *
     * @param dataSet       the source data
     * @param sampledCounts how many copies of each point to include
     * @return a new data set with the same variables and target variable as {@code dataSet}
     */
    public static ClassificationDataSet getSampledDataSet(ClassificationDataSet dataSet, int[] sampledCounts) {
        ClassificationDataSet destination = new ClassificationDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories(), dataSet.getPredicting());
        for (int i = 0; i < sampledCounts.length; ++i) {
            int copies = sampledCounts[i];
            if (copies <= 0) {
                continue;
            }
            DataPoint dp = dataSet.getDataPoint(i);
            for (int j = 0; j < copies; ++j) {
                destination.addDataPoint(dp.getNumericalValues(), dp.getCategoricalValues(), dataSet.getDataPointCategory(i));
            }
        }
        return destination;
    }

    /**
     * Builds a classification data set containing each point with a positive count exactly once,
     * weighted by its original weight multiplied by its count.
     *
     * @param dataSet       the source data
     * @param sampledCounts how many times each point was sampled
     * @return a new weighted data set with the same variables and target variable as {@code dataSet}
     */
    public static ClassificationDataSet getWeightSampledDataSet(ClassificationDataSet dataSet, int[] sampledCounts) {
        ClassificationDataSet destination = new ClassificationDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories(), dataSet.getPredicting());
        for (int i = 0; i < sampledCounts.length; ++i) {
            if (sampledCounts[i] <= 0) continue;
            DataPoint dp = dataSet.getDataPoint(i);
            destination.addDataPoint(dp, dataSet.getDataPointCategory(i), dataSet.getWeight(i) * (double)sampledCounts[i]);
        }
        return destination;
    }

    /**
     * Builds a regression data set in which point {@code i} of {@code dataSet} appears
     * {@code sampledCounts[i]} times, each copy with that point's target value. Only the first
     * {@code sampledCounts.length} points are considered; non-positive counts contribute nothing.
     *
     * @param dataSet       the source data
     * @param sampledCounts how many copies of each point to include
     * @return a new data set with the same variables as {@code dataSet}
     */
    public static RegressionDataSet getSampledDataSet(RegressionDataSet dataSet, int[] sampledCounts) {
        RegressionDataSet destination = new RegressionDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories());
        for (int i = 0; i < sampledCounts.length; ++i) {
            int copies = sampledCounts[i];
            if (copies <= 0) {
                continue;
            }
            DataPoint dp = dataSet.getDataPoint(i);
            for (int j = 0; j < copies; ++j) {
                destination.addDataPoint(dp, dataSet.getTargetValue(i));
            }
        }
        return destination;
    }

    /**
     * Builds a regression data set containing each point with a positive count exactly once,
     * weighted by its original weight multiplied by its count.
     *
     * @param dataSet       the source data
     * @param sampledCounts how many times each point was sampled
     * @return a new weighted data set with the same variables as {@code dataSet}
     */
    public static RegressionDataSet getWeightSampledDataSet(RegressionDataSet dataSet, int[] sampledCounts) {
        RegressionDataSet destination = new RegressionDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories());
        for (int i = 0; i < sampledCounts.length; ++i) {
            if (sampledCounts[i] <= 0) continue;
            DataPoint dp = dataSet.getDataPoint(i);
            destination.addDataPoint(dp, dataSet.getTargetValue(i), dataSet.getWeight(i) * (double)sampledCounts[i]);
        }
        return destination;
    }

    /**
     * Draws {@code samples} indices uniformly with replacement and stores how often each index was
     * drawn. The array is cleared first.
     *
     * @param sampleCounts output array; its length is the number of indices to draw from
     * @param samples      number of draws (no draws are made when non-positive)
     * @param rand         source of randomness
     * @throws IllegalArgumentException from {@link Random#nextInt(int)} if {@code sampleCounts} is
     *                                  empty and {@code samples} is positive
     */
    public static void sampleWithReplacement(int[] sampleCounts, int samples, Random rand) {
        Arrays.fill(sampleCounts, 0);
        for (int i = 0; i < samples; ++i) {
            sampleCounts[rand.nextInt(sampleCounts.length)]++;
        }
    }

    /**
     * Reports that this ensemble does not make use of data point weights.
     *
     * @return always {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Predicts a value as the mean of the trained models' predictions.
     *
     * @param data the point to predict
     * @return the average prediction
     * @throws RuntimeException if this instance was built for classification or has not been trained
     */
    @Override
    public double regress(DataPoint data) {
        if (this.baseRegressor == null) {
            throw new RuntimeException("Bagging instance created for classification, not regression");
        }
        if (this.learners == null || this.learners.isEmpty()) {
            throw new RuntimeException("Regressor has not yet been trained");
        }
        OnLineStatistics stats = new OnLineStatistics();
        for (Object learner : this.learners) {
            stats.add(((Regressor) learner).regress(data));
        }
        return stats.getMean();
    }

    /**
     * Trains {@link #getRounds()} clones of the base regressor, each on its own bootstrap sample.
     * The previously trained ensemble is replaced only if every round succeeds.
     *
     * @param dataSet  the training data
     * @param parallel whether multiple threads may be used
     * @throws RuntimeException if this instance was built for classification, if a model fails to
     *                          train while rounds run concurrently (the failure is the cause), or if
     *                          the calling thread is interrupted (its interrupt status is restored)
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        if (this.baseRegressor == null) {
            throw new RuntimeException("Bagging instance created for classification, not regression");
        }
        boolean concurrentRounds = this.simultaniousTraining && parallel;
        this.learners = this.trainRounds(dataSet.size(), concurrentRounds, sampleCounts -> {
            // Runs on the calling thread: the counts buffer is reused by the next round.
            RegressionDataSet sampleSet = Bagging.getSampledDataSet(dataSet, sampleCounts);
            Regressor learner = this.baseRegressor.clone();
            return () -> {
                if (concurrentRounds) {
                    // Rounds already run side by side, so the single-argument train method is used
                    // instead of forwarding the parallel flag.
                    learner.train(sampleSet);
                } else {
                    learner.train(sampleSet, parallel);
                }
                return learner;
            };
        });
    }

    /**
     * Returns a deep copy: the base model, the target description and every trained model are cloned.
     *
     * @return an independent copy of this ensemble
     */
    @Override
    public Bagging clone() {
        // NOTE(review): the copy gets a new Random seeded with the round count rather than a copy of
        // this instance's generator, so its future bootstrap samples need not match this instance's.
        Bagging clone = new Bagging(this.extraSamples, this.simultaniousTraining, this.rounds, new Random(this.rounds));
        if (this.baseClassifier != null) {
            clone.baseClassifier = this.baseClassifier.clone();
        }
        if (this.predicting != null) {
            clone.predicting = this.predicting.clone();
        }
        if (this.baseRegressor != null) {
            clone.baseRegressor = this.baseRegressor.clone();
        }
        if (this.learners != null && !this.learners.isEmpty()) {
            clone.learners = new ArrayList<>(this.learners.size());
            for (Object learner : this.learners) {
                if (learner instanceof Classifier) {
                    clone.learners.add(((Classifier) learner).clone());
                } else {
                    clone.learners.add(((Regressor) learner).clone());
                }
            }
        }
        return clone;
    }
}

