/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import jsat.DataSet;
import jsat.SimpleWeightVectorModel;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.WarmClassifier;
import jsat.classifiers.evaluation.ClassificationScore;
import jsat.classifiers.linear.LogisticRegressionDCD;
import jsat.classifiers.svm.DCDs;
import jsat.datatransform.ProjectionTransform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.lossfunctions.LogisticLoss;
import jsat.math.OnLineStatistics;
import jsat.parameters.GridSearch;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.regression.evaluation.RegressionScore;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

public class OWA
implements Classifier,
Regressor,
Parameterized,
SingleWeightVectorModel,
WarmClassifier {
    protected int min_points_per_core = 1000;
    protected int sample_multipler = 3;
    @Parameter.ParameterHolder
    private SingleWeightVectorModel base_learner;
    private boolean estimate_cv_scores = false;
    private List<ClassificationScore> scores_c = new ArrayList<ClassificationScore>();
    private List<RegressionScore> scores_r = new ArrayList<RegressionScore>();
    private List<OnLineStatistics> scores_stats = new ArrayList<OnLineStatistics>();
    private boolean warmTraining = false;
    private List<SimpleWeightVectorModel> prev_solutions = null;
    protected Vec w;
    protected double bias;

    public OWA(SingleWeightVectorModel base_learner) {
        this.base_learner = base_learner;
    }

    public OWA(OWA toClone) {
        this.min_points_per_core = toClone.min_points_per_core;
        this.sample_multipler = toClone.sample_multipler;
        this.base_learner = toClone.base_learner.clone();
        this.estimate_cv_scores = toClone.estimate_cv_scores;
        if (toClone.w != null) {
            this.w = toClone.w.clone();
            this.bias = toClone.bias;
        }
        this.scores_c = toClone.scores_c.stream().map(s -> s.clone()).collect(Collectors.toList());
        this.scores_r = toClone.scores_r.stream().map(s -> s.clone()).collect(Collectors.toList());
        this.scores_stats = toClone.scores_stats.stream().map(s -> s.clone()).collect(Collectors.toList());
        this.warmTraining = toClone.warmTraining;
        if (toClone.prev_solutions != null) {
            this.prev_solutions = toClone.prev_solutions.stream().map(s -> s.clone()).collect(Collectors.toList());
        }
    }

    public void setEstimateCV(boolean estimate_cv_scores) {
        this.estimate_cv_scores = estimate_cv_scores;
    }

    public boolean issetEstimateCV() {
        return this.estimate_cv_scores;
    }

    public void setWarmTraining(boolean warmTraining) {
        this.warmTraining = warmTraining;
    }

    public boolean isWarmTraining() {
        return this.warmTraining;
    }

    public void addScore(ClassificationScore score) {
        this.scores_c.add(score);
    }

    public Map<ClassificationScore, OnLineStatistics> getScoreStatsC() {
        HashMap<ClassificationScore, OnLineStatistics> results = new HashMap<ClassificationScore, OnLineStatistics>();
        for (int i = 0; i < Math.min(this.scores_c.size(), this.scores_stats.size()); ++i) {
            results.put(this.scores_c.get(i), this.scores_stats.get(i));
        }
        return results;
    }

    public Map<RegressionScore, OnLineStatistics> getScoreStatsR() {
        HashMap<RegressionScore, OnLineStatistics> results = new HashMap<RegressionScore, OnLineStatistics>();
        for (int i = 0; i < Math.min(this.scores_r.size(), this.scores_stats.size()); ++i) {
            results.put(this.scores_r.get(i), this.scores_stats.get(i));
        }
        return results;
    }

    public void addScore(RegressionScore score) {
        this.scores_r.add(score);
    }

    private void trainWork(int requested_cores, DataSet dataSet, boolean parallel, Object warmSolution) {
        Parameterized Z_model;
        DataSet Z_owa;
        ArrayList<Object> warm_starts;
        int m;
        int d = dataSet.getNumFeatures();
        int N = dataSet.size();
        int n = m = requested_cores <= 0 ? Math.min(Math.min(SystemInfo.LogicalCores, dataSet.size() / this.min_points_per_core), d / 2 + 1) : requested_cores;
        if (this.warmTraining && warmSolution != null) {
            if (!(this.base_learner instanceof WarmClassifier) && !(this.base_learner instanceof WarmRegressor)) {
                throw new FailedToFitException("Base class " + this.base_learner.getClass().getSimpleName() + " can not be trained via warm starts");
            }
            warm_starts = new ArrayList<Object>();
            if (warmSolution instanceof OWA && ((OWA)warmSolution).prev_solutions != null) {
                for (SimpleWeightVectorModel sol : ((OWA)warmSolution).prev_solutions) {
                    warm_starts.add(sol);
                }
            } else {
                warm_starts.add(warmSolution);
            }
            while (warm_starts.size() < m) {
                warm_starts.add(warm_starts.get(warm_starts.size() - 1));
            }
        } else {
            warm_starts = null;
        }
        List splits = dataSet.cvSet(m, RandomUtil.getRandom(m * dataSet.size()));
        List erms = ParallelUtils.streamP(IntStream.range(0, splits.size()), parallel).mapToObj(i -> {
            Object warm_w_i;
            DataSet data = (DataSet)splits.get(i);
            SingleWeightVectorModel w_i = this.base_learner.clone();
            Object v0 = warm_w_i = warm_starts == null ? null : warm_starts.get(i);
            if (w_i instanceof Classifier) {
                if (w_i instanceof WarmClassifier && warm_w_i != null) {
                    ((WarmClassifier)((Object)w_i)).train((ClassificationDataSet)data, warm_w_i, false);
                } else {
                    ((Classifier)((Object)w_i)).train((ClassificationDataSet)data);
                }
            } else if (w_i instanceof WarmClassifier && warm_w_i != null) {
                ((WarmRegressor)((Object)w_i)).train((RegressionDataSet)data, warm_w_i, false);
            } else {
                ((Regressor)((Object)w_i)).train((RegressionDataSet)data);
            }
            return w_i;
        }).collect(Collectors.toList());
        if (this.warmTraining) {
            this.prev_solutions = erms;
        }
        DenseMatrix W = new DenseMatrix(m, d);
        DenseVector b = new DenseVector(m);
        for (int i2 = 0; i2 < m; ++i2) {
            SingleWeightVectorModel w_i = (SingleWeightVectorModel)erms.get(i2);
            w_i.getRawWeight().copyToRow(W, i2);
            ((Vec)b).set(i2, w_i.getBias());
        }
        ProjectionTransform t = new ProjectionTransform(W, b);
        double sub_sample_frac = Math.min(Math.max((double)(this.sample_multipler * m) / (double)d + 40.0 / (double)N, (double)((m + 40) * this.sample_multipler) / (double)(N / m)), 1.0);
        List<ClassificationDataSet> Z_owa_splits = splits.parallelStream().map(data -> {
            DataSet z_i = (DataSet)data.randomSplit(sub_sample_frac).get(0);
            if (!data.rowMajor()) {
                Iterator<DataPoint> orig_iter = data.getDataPointIterator();
                int pos = 0;
                if (data instanceof ClassificationDataSet) {
                    ClassificationDataSet new_data = new ClassificationDataSet(W.rows(), new CategoricalData[0], ((ClassificationDataSet)z_i).getPredicting());
                    while (orig_iter.hasNext()) {
                        new_data.addDataPoint(t.transform(orig_iter.next()), ((ClassificationDataSet)z_i).getDataPointCategory(pos++));
                    }
                    z_i = new_data;
                } else {
                    RegressionDataSet new_data = new RegressionDataSet(W.rows(), new CategoricalData[0]);
                    while (orig_iter.hasNext()) {
                        new_data.addDataPoint(t.transform(orig_iter.next()), ((RegressionDataSet)z_i).getTargetValue(pos++));
                    }
                    z_i = new_data;
                }
            } else {
                z_i.applyTransform(t);
            }
            return z_i;
        }).collect(Collectors.toList());
        if (this.estimate_cv_scores) {
            this.scores_stats.clear();
            ParallelUtils.streamP(IntStream.range(0, m), true).forEach(id -> {
                SimpleWeightVectorModel cv_model;
                Parameterized Z_model;
                DataSet Z_owa_mi;
                if (dataSet instanceof ClassificationDataSet) {
                    Z_owa_mi = ClassificationDataSet.comineAllBut(Z_owa_splits, id);
                    LogisticRegressionDCD lr = new LogisticRegressionDCD();
                    lr.setUseBias(false);
                    Z_model = lr;
                } else {
                    Z_owa_mi = RegressionDataSet.comineAllBut(Z_owa_splits, id);
                    DCDs dcd = new DCDs();
                    dcd.setUseBias(false);
                    Z_model = dcd;
                }
                Z_owa_mi.applyTransform(dp -> {
                    Vec v = dp.getNumericalValues().clone();
                    v.set(id, 0.0);
                    return new DataPoint(v);
                });
                GridSearch rs = new GridSearch((Classifier)((Object)Z_model), 5);
                rs.setUseWarmStarts(true);
                rs.autoAddParameters(Z_owa_mi, 9);
                rs.setTrainModelsInParallel(false);
                rs.setTrainFinalModel(true);
                if (Z_owa_mi instanceof ClassificationDataSet) {
                    rs.train((ClassificationDataSet)Z_owa_mi, false);
                    cv_model = (SimpleWeightVectorModel)((Object)rs.getTrainedClassifier());
                } else {
                    rs.train((RegressionDataSet)Z_owa_mi, false);
                    cv_model = (SimpleWeightVectorModel)((Object)rs.getTrainedRegressor());
                }
                cv_model.getRawWeight(0).set(id, 0.0);
                DenseVector w_mi = new DenseVector(d);
                DenseVector b_mi = new DenseVector(1);
                this.accumulateUpdates(m, cv_model, w_mi, b_mi, W, b);
                if (Z_owa_mi instanceof ClassificationDataSet) {
                    Object result;
                    ClassificationDataSet cds = (ClassificationDataSet)splits.get(id);
                    List scores = this.scores_c.stream().map(s -> s.clone()).collect(Collectors.toList());
                    for (ClassificationScore s2 : scores) {
                        s2.prepare(cds.getPredicting());
                    }
                    int pos = 0;
                    Iterator<DataPoint> iter = cds.getDataPointIterator();
                    Vec weights = cds.getDataWeights();
                    while (iter.hasNext()) {
                        result = LogisticLoss.classify(((Vec)w_mi).dot(iter.next().getNumericalValues()) + ((Vec)b_mi).get(0));
                        for (ClassificationScore s3 : scores) {
                            s3.addResult((CategoricalResults)result, cds.getDataPointCategory(pos), weights.get(pos));
                        }
                        ++pos;
                    }
                    result = this.scores_stats;
                    synchronized (result) {
                        if (this.scores_stats.isEmpty()) {
                            for (ClassificationScore s3 : scores) {
                                this.scores_stats.add(new OnLineStatistics());
                            }
                        }
                        for (int i = 0; i < scores.size(); ++i) {
                            this.scores_stats.get(i).add(((ClassificationScore)scores.get(i)).getScore());
                        }
                    }
                }
                RegressionDataSet rds = (RegressionDataSet)splits.get(id);
                List scores = this.scores_r.stream().map(s -> s.clone()).collect(Collectors.toList());
                int pos = 0;
                Iterator<DataPoint> iter = rds.getDataPointIterator();
                Vec weights = rds.getDataWeights();
                while (iter.hasNext()) {
                    double result = ((Vec)w_mi).dot(iter.next().getNumericalValues()) + ((Vec)b_mi).get(0);
                    for (RegressionScore s4 : scores) {
                        s4.addResult(result, rds.getTargetValue(pos), weights.get(pos));
                    }
                    ++pos;
                }
                List<OnLineStatistics> list = this.scores_stats;
                synchronized (list) {
                    if (this.scores_stats.isEmpty()) {
                        for (Object s3 : scores) {
                            this.scores_stats.add(new OnLineStatistics());
                        }
                    }
                    for (int i = 0; i < scores.size(); ++i) {
                        this.scores_stats.get(i).add(((RegressionScore)scores.get(i)).getScore());
                    }
                }
            });
        }
        if (dataSet instanceof ClassificationDataSet) {
            Z_owa = ClassificationDataSet.comineAllBut(Z_owa_splits, -1);
            LogisticRegressionDCD lr = new LogisticRegressionDCD();
            lr.setUseBias(false);
            Z_model = lr;
        } else {
            Z_owa = RegressionDataSet.comineAllBut(Z_owa_splits, -1);
            DCDs dcd = new DCDs();
            dcd.setUseBias(false);
            Z_model = dcd;
        }
        GridSearch rs = new GridSearch((Classifier)((Object)Z_model), 5);
        rs.setUseWarmStarts(false);
        rs.autoAddParameters(Z_owa, 9);
        rs.setTrainModelsInParallel(true);
        rs.setTrainFinalModel(true);
        if (Z_owa instanceof ClassificationDataSet) {
            rs.train((ClassificationDataSet)Z_owa, parallel);
        } else {
            rs.train((RegressionDataSet)Z_owa, parallel);
        }
        SimpleWeightVectorModel weight_model = (SimpleWeightVectorModel)((Object)rs.getTrainedClassifier());
        DenseVector w_final = new DenseVector(d);
        DenseVector b_final = new DenseVector(1);
        this.accumulateUpdates(m, weight_model, w_final, b_final, W, b);
        this.w = w_final;
        this.bias = ((Vec)b_final).get(0);
    }

    private void accumulateUpdates(int m, SimpleWeightVectorModel w_i_weights_source, Vec w_final, Vec b_final, Matrix W, Vec b) {
        Vec w_i_weights = w_i_weights_source.getRawWeight(0).clone();
        if (w_i_weights.min() >= 0.0) {
            w_i_weights.mutableDivide(w_i_weights.sum());
        }
        for (int i = 0; i < m; ++i) {
            w_final.mutableAdd(w_i_weights.get(i), W.getRowView(i));
            b_final.increment(0, w_i_weights.get(i) * b.get(i));
        }
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        return LogisticLoss.classify(this.w.dot(data.getNumericalValues()) + this.bias);
    }

    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        Classifier c_base = (Classifier)((Object)this.base_learner);
        if (!parallel) {
            c_base.train(dataSet);
            this.w = this.base_learner.getRawWeight();
            this.bias = this.base_learner.getBias();
            return;
        }
        this.trainWork(-1, dataSet, parallel, null);
    }

    @Override
    public boolean supportsWeightedData() {
        if (this.base_learner instanceof Classifier) {
            return ((Classifier)((Object)this.base_learner)).supportsWeightedData();
        }
        return ((Regressor)((Object)this.base_learner)).supportsWeightedData();
    }

    @Override
    public double regress(DataPoint data) {
        return this.w.dot(data.getNumericalValues()) + this.bias;
    }

    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        Regressor r_base = (Regressor)((Object)this.base_learner);
        if (!parallel) {
            r_base.train(dataSet);
            this.w = this.base_learner.getRawWeight();
            this.bias = this.base_learner.getBias();
            return;
        }
        this.trainWork(-1, dataSet, parallel, null);
    }

    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    @Override
    public double getBias() {
        return this.bias;
    }

    @Override
    public OWA clone() {
        return new OWA(this);
    }

    @Override
    public boolean warmFromSameDataOnly() {
        if (this.base_learner instanceof WarmClassifier) {
            return ((WarmClassifier)((Object)this.base_learner)).warmFromSameDataOnly();
        }
        if (this.base_learner instanceof WarmRegressor) {
            return ((WarmRegressor)((Object)this.base_learner)).warmFromSameDataOnly();
        }
        return false;
    }

    @Override
    public void train(ClassificationDataSet dataSet, Classifier warmSolution, boolean parallel) {
        Classifier c_base = (Classifier)((Object)this.base_learner);
        if (!parallel) {
            c_base.train(dataSet);
            this.w = this.base_learner.getRawWeight();
            this.bias = this.base_learner.getBias();
            return;
        }
        this.trainWork(-1, dataSet, parallel, warmSolution);
    }
}

