/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.UpdateableClassifier;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.regression.BaseUpdateableRegressor;
import jsat.regression.RegressionDataSet;
import jsat.regression.UpdateableRegressor;

/**
 * A stacking ensemble in which the base models and the aggregating (meta) model
 * are all updateable, so the ensemble is trained one data point at a time.
 * <p>
 * For every incoming point, the aggregating model is first updated on the
 * predictions the base models make for that point. Only afterwards are the base
 * models updated on the point itself, so the aggregating model always learns
 * from predictions made before the base models have seen that point.
 * <p>
 * An ensemble built from classifiers can also act as a regressor when the
 * aggregating model and every base model implement {@link UpdateableRegressor}.
 * Likewise, an ensemble built from regressors can act as a classifier when they
 * all implement {@link UpdateableClassifier}.
 */
public class UpdatableStacking
implements UpdateableClassifier,
UpdateableRegressor {
    private static final long serialVersionUID = -5111303510263114862L;
    /** Number of input features each base model contributes to the aggregating model. */
    private int weightsPerModel;
    private UpdateableClassifier aggregatingClassifier;
    private List<UpdateableClassifier> baseClassifiers;
    private UpdateableRegressor aggregatingRegressor;
    private List<UpdateableRegressor> baseRegressors;

    /**
     * Creates a stacking classifier.
     *
     * @param aggregatingClassifier the model trained on the base classifiers' predictions
     * @param baseClassifiers the base models; at least 2 are required. The list is
     *        copied, so later changes to it do not affect this ensemble.
     * @throws IllegalArgumentException if fewer than 2 base classifiers are given
     */
    public UpdatableStacking(UpdateableClassifier aggregatingClassifier, List<UpdateableClassifier> baseClassifiers) {
        if (baseClassifiers.size() < 2) {
            throw new IllegalArgumentException("base classifiers must contain at least 2 elements, not " + baseClassifiers.size());
        }
        this.aggregatingClassifier = aggregatingClassifier;
        this.baseClassifiers = new ArrayList<>(baseClassifiers);
        // If every model can also regress, expose the same models through the regressor API
        if (aggregatingClassifier instanceof UpdateableRegressor && allInstancesOf(this.baseClassifiers, UpdateableRegressor.class)) {
            this.aggregatingRegressor = (UpdateableRegressor) aggregatingClassifier;
            this.baseRegressors = castList(this.baseClassifiers);
        }
    }

    /**
     * Creates a stacking classifier.
     *
     * @param aggregatingClassifier the model trained on the base classifiers' predictions
     * @param baseClassifiers the base models; at least 2 are required
     * @throws IllegalArgumentException if fewer than 2 base classifiers are given
     */
    public UpdatableStacking(UpdateableClassifier aggregatingClassifier, UpdateableClassifier ... baseClassifiers) {
        this(aggregatingClassifier, Arrays.asList(baseClassifiers));
    }

    /**
     * Creates a stacking regressor.
     *
     * @param aggregatingRegressor the model trained on the base regressors' predictions
     * @param baseRegressors the base models; at least 2 are required. The list is
     *        copied, so later changes to it do not affect this ensemble.
     * @throws IllegalArgumentException if fewer than 2 base regressors are given
     */
    public UpdatableStacking(UpdateableRegressor aggregatingRegressor, List<UpdateableRegressor> baseRegressors) {
        if (baseRegressors.size() < 2) {
            throw new IllegalArgumentException("base regressors must contain at least 2 elements, not " + baseRegressors.size());
        }
        this.aggregatingRegressor = aggregatingRegressor;
        this.baseRegressors = new ArrayList<>(baseRegressors);
        // If every model can also classify, expose the same models through the classifier API
        if (aggregatingRegressor instanceof UpdateableClassifier && allInstancesOf(this.baseRegressors, UpdateableClassifier.class)) {
            this.aggregatingClassifier = (UpdateableClassifier) aggregatingRegressor;
            this.baseClassifiers = castList(this.baseRegressors);
        }
    }

    /**
     * Creates a stacking regressor.
     *
     * @param aggregatingRegressor the model trained on the base regressors' predictions
     * @param baseRegressors the base models; at least 2 are required
     * @throws IllegalArgumentException if fewer than 2 base regressors are given
     */
    public UpdatableStacking(UpdateableRegressor aggregatingRegressor, UpdateableRegressor ... baseRegressors) {
        this(aggregatingRegressor, Arrays.asList(baseRegressors));
    }

    /**
     * Copy constructor. Clones the aggregating model and every base model.
     *
     * @param toCopy the ensemble to copy
     */
    public UpdatableStacking(UpdatableStacking toCopy) {
        this.weightsPerModel = toCopy.weightsPerModel;
        if (toCopy.aggregatingClassifier != null) {
            this.aggregatingClassifier = toCopy.aggregatingClassifier.clone();
            this.baseClassifiers = new ArrayList<>(toCopy.baseClassifiers.size());
            for (UpdateableClassifier bc : toCopy.baseClassifiers) {
                this.baseClassifiers.add(bc.clone());
            }
            // When the source shares one set of models between both roles, the copy must share too
            if (toCopy.aggregatingRegressor == toCopy.aggregatingClassifier) {
                this.aggregatingRegressor = (UpdateableRegressor) this.aggregatingClassifier;
                this.baseRegressors = castList(this.baseClassifiers);
            }
        } else {
            this.aggregatingRegressor = toCopy.aggregatingRegressor.clone();
            this.baseRegressors = new ArrayList<>(toCopy.baseRegressors.size());
            for (UpdateableRegressor br : toCopy.baseRegressors) {
                this.baseRegressors.add(br.clone());
            }
        }
    }

    /**
     * Returns {@code true} when every element of {@code items} is an instance of {@code type}.
     */
    private static boolean allInstancesOf(List<?> items, Class<?> type) {
        for (Object item : items) {
            if (!type.isInstance(item)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Reinterprets the element type of {@code list} without copying it. The caller
     * must already have verified that every element is an instance of {@code T}.
     */
    @SuppressWarnings("unchecked")
    private static <T> List<T> castList(List<?> list) {
        return (List<T>) list;
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        return this.aggregatingClassifier.classify(this.getPredVecC(data));
    }

    /**
     * Builds the aggregating classifier's input from the base classifiers' predictions.
     * <p>
     * With one feature per model (binary problems), each base model contributes its
     * probability for class 0, rescaled from [0, 1] to [-1, 1]. Otherwise each base
     * model contributes {@code weightsPerModel} consecutive class probabilities.
     */
    private DataPoint getPredVecC(DataPoint data) {
        DenseVector w = new DenseVector(this.weightsPerModel * this.baseClassifiers.size());
        for (int i = 0; i < this.baseClassifiers.size(); ++i) {
            CategoricalResults pred = this.baseClassifiers.get(i).classify(data);
            if (this.weightsPerModel == 1) {
                w.set(i, pred.getProb(0) * 2.0 - 1.0);
            } else {
                for (int j = 0; j < this.weightsPerModel; ++j) {
                    w.set(i * this.weightsPerModel + j, pred.getProb(j));
                }
            }
        }
        return new DataPoint(w);
    }

    /**
     * Builds the aggregating regressor's input: one feature per base regressor,
     * holding that regressor's prediction for {@code data}.
     */
    private DataPoint getPredVecR(DataPoint data) {
        DenseVector w = new DenseVector(this.baseRegressors.size());
        for (int i = 0; i < this.baseRegressors.size(); ++i) {
            w.set(i, this.baseRegressors.get(i).regress(data));
        }
        return new DataPoint(w);
    }

    /**
     * Prepares all models for classification. Binary problems use one input feature
     * per base model for the aggregating classifier; otherwise one feature per class
     * per base model. The base models receive the original attribute description.
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        int numClasses = predicting.getNumOfCategories();
        this.weightsPerModel = numClasses == 2 ? 1 : numClasses;
        this.aggregatingClassifier.setUp(new CategoricalData[0], this.weightsPerModel * this.baseClassifiers.size(), predicting);
        for (UpdateableClassifier uc : this.baseClassifiers) {
            uc.setUp(categoricalAttributes, numericAttributes, predicting);
        }
    }

    /**
     * Updates the ensemble with one labeled point. The aggregating classifier is
     * updated first, using predictions the base models make before they see this
     * point. Each base classifier is then updated with the point and the same weight.
     */
    @Override
    public void update(DataPoint dataPoint, double weight, int targetClass) {
        this.aggregatingClassifier.update(this.getPredVecC(dataPoint), weight, targetClass);
        for (UpdateableClassifier uc : this.baseClassifiers) {
            uc.update(dataPoint, weight, targetClass);
        }
    }

    /**
     * Prepares all models for regression. The aggregating regressor receives one
     * numeric feature per base regressor.
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes) {
        this.weightsPerModel = 1;
        this.aggregatingRegressor.setUp(new CategoricalData[0], this.weightsPerModel * this.baseRegressors.size());
        for (UpdateableRegressor ur : this.baseRegressors) {
            ur.setUp(categoricalAttributes, numericAttributes);
        }
    }

    /**
     * Updates the ensemble with one regression target. The aggregating regressor is
     * updated first, using predictions the base models make before they see this
     * point. Each base regressor is then updated with the point and the same weight.
     */
    @Override
    public void update(DataPoint dataPoint, double weight, double targetValue) {
        this.aggregatingRegressor.update(this.getPredVecR(dataPoint), weight, targetValue);
        for (UpdateableRegressor ur : this.baseRegressors) {
            ur.update(dataPoint, weight, targetValue);
        }
    }

    /** Trains sequentially; the {@code parallel} flag is ignored. */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /** Delegates to {@code BaseUpdateableClassifier.trainEpochs} with a single epoch. */
    @Override
    public void train(ClassificationDataSet dataSet) {
        BaseUpdateableClassifier.trainEpochs(dataSet, this, 1);
    }

    /** Reports whether the aggregating model supports weighted data. */
    @Override
    public boolean supportsWeightedData() {
        if (this.aggregatingClassifier != null) {
            return this.aggregatingClassifier.supportsWeightedData();
        }
        return this.aggregatingRegressor.supportsWeightedData();
    }

    @Override
    public double regress(DataPoint data) {
        return this.aggregatingRegressor.regress(this.getPredVecR(data));
    }

    /** Trains sequentially; the {@code parallel} flag is ignored. */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /** Delegates to {@code BaseUpdateableRegressor.trainEpochs} with a single epoch. */
    @Override
    public void train(RegressionDataSet dataSet) {
        BaseUpdateableRegressor.trainEpochs(dataSet, this, 1);
    }

    /** Returns a deep copy of this ensemble with every model cloned. */
    @Override
    public UpdatableStacking clone() {
        return new UpdatableStacking(this);
    }
}

