/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.neuralnetwork;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.distributions.Normal;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.math.Function1D;
import jsat.math.decayrates.DecayRate;
import jsat.math.decayrates.ExponetialDecay;
import jsat.parameters.IntParameter;
import jsat.parameters.ObjectParameter;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.random.RandomUtil;

public class BackPropagationNet
implements Classifier,
Regressor,
Parameterized {
    private static final long serialVersionUID = 335438198218313862L;
    /** Number of numeric input features; assigned by the train methods from the data set. */
    private int inputSize;
    /** Number of output neurons: the class count for classification, 1 for regression. */
    private int outputSize;
    /** Activation function applied to every layer, including the output layer. */
    private ActivationFunction f = softsignActiv;
    /** Schedule queried once per epoch to obtain that epoch's learning rate. */
    private DecayRate learningRateDecay = new ExponetialDecay();
    /** Factor the accumulated update is multiplied by before a new gradient is added; 0 disables momentum. */
    private double momentum = 0.1;
    /** Weight decay coefficient; scaled by the learning rate and applied as a proportional shrink of the weights. */
    private double weightDecay = 0.0;
    /** Number of passes made over the training data. */
    private int epochs = 1000;
    /** Base learning rate handed to {@link #learningRateDecay} and to the weight initializer. */
    private double initialLearningRate = 0.2;
    /** Strategy used to draw the initial weights and biases. */
    private WeightInitialization weightInitialization = WeightInitialization.TANH_NORMALIZED_INITIALIZATION;
    /** Margin that keeps training targets away from the activation function's min and max. */
    private double targetBump = 0.1;
    /** Number of data points processed per mini batch. */
    private int batchSize = 10;
    /** Neurons per hidden layer, in order from the input side to the output side. */
    private int[] npl;
    /** Weight matrix of each layer (hidden layers followed by the output layer); built in setUp. */
    private List<Matrix> Ws;
    /** Bias vector of each layer, parallel to {@link #Ws}. */
    private List<Vec> bs;
    /** Largest regression target seen during training. */
    private double targetMax;
    /** Smallest regression target seen during training. */
    private double targetMin;
    /** Scale factor mapping the regression target range onto the usable activation range. */
    private double targetMultiplier;
    /**
     * Logistic sigmoid activation. Reports 0 and 1 as its lower and upper
     * bounds; its derivative function is evaluated on the response value.
     */
    public static final ActivationFunction logitActiv = new ActivationFunction(){
        private static final long serialVersionUID = -5675881412853268432L;

        @Override
        public double response(double x) {
            double expNeg = Math.exp(-x);
            return 1.0 / (1.0 + expNeg);
        }

        @Override
        public double min() {
            return 0.0;
        }

        @Override
        public double max() {
            return 1.0;
        }

        @Override
        public Function1D getD() {
            return logitPrime;
        }

        @Override
        public String toString() {
            return "Logit";
        }
    };
    /** Derivative of the logistic sigmoid, expressed in terms of the sigmoid's own output value. */
    private static final Function1D logitPrime = out -> out * (1.0 - out);
    /**
     * Hyperbolic tangent activation. Reports -1 and 1 as its bounds; its
     * derivative function is evaluated on the response value.
     */
    public static final ActivationFunction tanhActiv = new ActivationFunction(){
        private static final long serialVersionUID = 5531922338473526216L;

        @Override
        public double response(double x) {
            return Math.tanh(x);
        }

        @Override
        public double min() {
            return -1.0;
        }

        @Override
        public double max() {
            return 1.0;
        }

        @Override
        public Function1D getD() {
            // With y = tanh(x), the derivative is 1 - y^2.
            return y -> 1.0 - y * y;
        }

        @Override
        public String toString() {
            return "Tanh";
        }
    };
    /**
     * Softsign activation, {@code x / (1 + |x|)}. Reports -1 and 1 as its
     * bounds; its derivative function is evaluated on the response value.
     */
    public static final ActivationFunction softsignActiv = new ActivationFunction(){
        private static final long serialVersionUID = 1618447580574194519L;

        @Override
        public double response(double x) {
            double denominator = 1.0 + Math.abs(x);
            return x / denominator;
        }

        @Override
        public double min() {
            return -1.0;
        }

        @Override
        public double max() {
            return 1.0;
        }

        @Override
        public Function1D getD() {
            // With y = x / (1 + |x|), the derivative is (1 - |y|)^2.
            return y -> Math.pow(1.0 - Math.abs(y), 2.0);
        }

        @Override
        public String toString() {
            return "Softsign";
        }
    };

    /**
     * Creates a network with a single hidden layer of 1024 neurons.
     */
    public BackPropagationNet() {
        this(1024);
    }

    /**
     * Creates a network with the given hidden layer sizes.
     *
     * @param npl the number of neurons in each hidden layer, ordered from the
     * input side to the output side; at least one layer is required and every
     * layer must contain at least one neuron
     * @throws IllegalArgumentException if no layer is given or a layer size is
     * not positive
     */
    public BackPropagationNet(int ... npl) {
        if (npl.length < 1) {
            throw new IllegalArgumentException("There must be at least one hidden layer");
        }
        // Fail fast on an unusable architecture instead of when the weights are built.
        for (int layerSize : npl) {
            if (layerSize < 1) {
                throw new IllegalArgumentException("There must be a positive number of hidden neurons in each layer, not " + layerSize);
            }
        }
        // Defensive copy: later changes to the caller's array must not alter this network.
        this.npl = Arrays.copyOf(npl, npl.length);
    }

    /**
     * Copy constructor. Scalar settings are copied by value, the activation
     * function, decay rate and initialization strategy are shared by
     * reference, and every weight matrix and bias vector is cloned.
     *
     * @param toClone the network to copy
     */
    protected BackPropagationNet(BackPropagationNet toClone) {
        this(Arrays.copyOf(toClone.npl, toClone.npl.length));
        inputSize = toClone.inputSize;
        outputSize = toClone.outputSize;
        f = toClone.f;
        momentum = toClone.momentum;
        weightDecay = toClone.weightDecay;
        epochs = toClone.epochs;
        initialLearningRate = toClone.initialLearningRate;
        learningRateDecay = toClone.learningRateDecay;
        weightInitialization = toClone.weightInitialization;
        targetBump = toClone.targetBump;
        targetMax = toClone.targetMax;
        targetMin = toClone.targetMin;
        targetMultiplier = toClone.targetMultiplier;
        batchSize = toClone.batchSize;

        // The learned state only exists once the source network has been trained.
        if (toClone.Ws != null) {
            Ws = new ArrayList<Matrix>(toClone.Ws.size());
            for (Matrix weights : toClone.Ws) {
                Ws.add(weights.clone());
            }
        }
        if (toClone.bs != null) {
            bs = new ArrayList<Vec>(toClone.bs.size());
            for (Vec bias : toClone.bs) {
                bs.add(bias.clone());
            }
        }
    }

    /**
     * Runs mini batch back propagation over the data set for the configured
     * number of epochs, updating {@code Ws} and {@code bs} in place. The
     * weights must already have been created by {@code setUp}.
     * <p>
     * Each epoch visits the data points in a freshly shuffled order. For every
     * batch all points are first fed forward, then each point in turn has its
     * error propagated backwards and its (batch size scaled) gradient step
     * applied. Points left over after the last full batch of an epoch are
     * skipped for that epoch.
     * <p>
     * If the data set holds fewer points than the configured batch size, the
     * whole data set is used as one batch; otherwise no batch would ever be
     * formed and the network would be left at its random initialization.
     *
     * @param dataSet the classification or regression data to learn from
     * @throws IllegalStateException if the configured batch size is not positive
     */
    private void trainNN(DataSet dataSet) {
        final int n = dataSet.size();
        if (n == 0) {
            return; // nothing to learn from
        }
        if (this.batchSize < 1) {
            // A zero batch size would never advance the batch loop below.
            throw new IllegalStateException("Batch size must be positive, not " + this.batchSize);
        }
        final int batch = Math.min(this.batchSize, n);
        final int layers = this.Ws.size();
        final int outLayer = layers - 1;

        // Per-sample scratch space, allocated once and reused for every batch.
        List<List<Vec>> activations = new ArrayList<List<Vec>>(batch);
        List<List<Vec>> derivatives = new ArrayList<List<Vec>>(batch);
        List<List<Vec>> deltas = new ArrayList<List<Vec>>(batch);
        for (int bi = 0; bi < batch; ++bi) {
            List<Vec> a = new ArrayList<Vec>(layers);
            List<Vec> d = new ArrayList<Vec>(layers);
            List<Vec> del = new ArrayList<Vec>(layers);
            for (Matrix w : this.Ws) {
                int rows = w.rows();
                a.add(new DenseVector(rows));
                d.add(new DenseVector(rows));
                del.add(new DenseVector(rows));
            }
            activations.add(a);
            derivatives.add(d);
            deltas.add(del);
        }
        // One accumulated update per layer, used only when momentum is enabled.
        List<Matrix> updates = new ArrayList<Matrix>(layers);
        for (Matrix w : this.Ws) {
            updates.add(new DenseMatrix(w.rows(), w.cols()));
        }
        List<Vec> cur_x = new ArrayList<Vec>(batch);

        IntList iterOrder = new IntList(n);
        ListUtils.addRange(iterOrder, 0, n, 1);
        final double bSizeInv = 1.0 / (double)batch;

        for (int epoch = 0; epoch < this.epochs; ++epoch) {
            Collections.shuffle(iterOrder);
            double eta = this.learningRateDecay.rate(epoch, this.epochs, this.initialLearningRate);
            double step = -eta * bSizeInv;
            // Only full batches are processed.
            for (int iter = 0; iter + batch <= n; iter += batch) {
                // Forward pass and output layer error for every point in the batch.
                cur_x.clear();
                for (int bi = 0; bi < batch; ++bi) {
                    int idx = iterOrder.get(iter + bi);
                    Vec x = dataSet.getDataPoint(idx).getNumericalValues();
                    cur_x.add(x);
                    List<Vec> a = activations.get(bi);
                    List<Vec> d = derivatives.get(bi);
                    this.feedForward(x, a, d);
                    // Called for its side effect of filling in the output delta.
                    this.computeOutputDelta(dataSet, idx, deltas.get(bi).get(outLayer), a.get(outLayer), d.get(outLayer));
                }
                // Backward pass and weight update, one point at a time.
                for (int bi = 0; bi < batch; ++bi) {
                    List<Vec> a = activations.get(bi);
                    List<Vec> d = derivatives.get(bi);
                    List<Vec> del = deltas.get(bi);
                    for (int i = layers - 2; i >= 0; --i) {
                        Vec delta = del.get(i);
                        delta.zeroOut();
                        this.Ws.get(i + 1).transposeMultiply(1.0, del.get(i + 1), delta);
                        delta.mutablePairwiseMultiply(d.get(i));
                    }
                    // All deltas are final here and each layer's update touches only
                    // that layer's own state, so the layers can be updated in any order.
                    for (int i = 0; i < layers; ++i) {
                        Vec layerInput = i == 0 ? cur_x.get(bi) : a.get(i - 1);
                        this.applyGradientStep(i, updates.get(i), del.get(i), layerInput, eta, step);
                    }
                }
            }
        }
    }

    /**
     * Applies weight decay and one sample's gradient step to a single layer.
     *
     * @param layer index of the layer to update
     * @param update the layer's accumulated momentum update
     * @param delta the error signal of the layer for the current sample
     * @param layerInput the vector that was fed into the layer for the current sample
     * @param eta the learning rate of the current epoch
     * @param step the (negative) learning rate already divided by the batch size
     */
    private void applyGradientStep(int layer, Matrix update, Vec delta, Vec layerInput, double eta, double step) {
        Matrix W = this.Ws.get(layer);
        W.mutableSubtract(eta * this.weightDecay, W);
        if (this.momentum != 0.0) {
            update.mutableMultiply(this.momentum);
            Matrix.OuterProductUpdate(update, delta, layerInput, step);
            W.mutableAdd(update);
        } else {
            Matrix.OuterProductUpdate(W, delta, layerInput, step);
        }
        this.bs.get(layer).mutableAdd(step, delta);
    }

    /**
     * Sets the momentum factor applied to the accumulated weight update. A
     * value of zero disables momentum.
     *
     * @param momentum a finite, non negative value
     * @throws ArithmeticException if the value is negative, NaN or infinite
     */
    public void setMomentum(double momentum) {
        boolean finite = !Double.isNaN(momentum) && !Double.isInfinite(momentum);
        if (!finite || momentum < 0.0) {
            throw new ArithmeticException("Momentum must be non negative, not " + momentum);
        }
        this.momentum = momentum;
    }

    /**
     * @return the momentum factor currently in use
     */
    public double getMomentum() {
        return momentum;
    }

    /**
     * Sets the base learning rate that the decay schedule starts from.
     *
     * @param initialLearningRate a finite, strictly positive value
     * @throws ArithmeticException if the value is not positive, is NaN or is infinite
     */
    public void setInitialLearningRate(double initialLearningRate) {
        boolean finite = !Double.isNaN(initialLearningRate) && !Double.isInfinite(initialLearningRate);
        if (!finite || initialLearningRate <= 0.0) {
            throw new ArithmeticException("Learning rate must be a positive cosntant, not " + initialLearningRate);
        }
        this.initialLearningRate = initialLearningRate;
    }

    /**
     * @return the base learning rate
     */
    public double getInitialLearningRate() {
        return initialLearningRate;
    }

    /**
     * Sets the schedule used to derive each epoch's learning rate from the
     * initial learning rate.
     *
     * @param learningRateDecay the decay schedule, must not be null
     * @throws NullPointerException if the argument is null
     */
    public void setLearningRateDecay(DecayRate learningRateDecay) {
        // Reject null here rather than failing in the middle of training.
        if (learningRateDecay == null) {
            throw new NullPointerException("Learning rate decay must not be null");
        }
        this.learningRateDecay = learningRateDecay;
    }

    /**
     * @return the learning rate decay schedule in use
     */
    public DecayRate getLearningRateDecay() {
        return learningRateDecay;
    }

    /**
     * Sets how many passes over the training data are performed.
     *
     * @param epochs the number of epochs, at least one
     * @throws ArithmeticException if the value is not positive
     */
    public void setEpochs(int epochs) {
        if (epochs <= 0) {
            throw new ArithmeticException("number of training epochs must be positive, not " + epochs);
        }
        this.epochs = epochs;
    }

    /**
     * @return the number of training epochs
     */
    public int getEpochs() {
        return epochs;
    }

    /**
     * Sets the weight decay coefficient used to shrink the weights during
     * training.
     *
     * @param weightDecay a value in the range [0, 1)
     * @throws ArithmeticException if the value is NaN or outside [0, 1)
     */
    public void setWeightDecay(double weightDecay) {
        // NaN fails both comparisons and is therefore rejected as well.
        boolean inRange = weightDecay >= 0.0 && weightDecay < 1.0;
        if (!inRange) {
            throw new ArithmeticException("Weight decay must be in [0,1), not " + weightDecay);
        }
        this.weightDecay = weightDecay;
    }

    /**
     * @return the weight decay coefficient
     */
    public double getWeightDecay() {
        return weightDecay;
    }

    /**
     * Sets the strategy used to draw the initial weights and biases when
     * training starts.
     *
     * @param weightInitialization the initialization strategy, must not be null
     * @throws NullPointerException if the argument is null
     */
    public void setWeightInitialization(WeightInitialization weightInitialization) {
        // Reject null here rather than failing when the weights are built.
        if (weightInitialization == null) {
            throw new NullPointerException("Weight initialization must not be null");
        }
        this.weightInitialization = weightInitialization;
    }

    /**
     * @return the weight initialization strategy in use
     */
    public WeightInitialization getWeightInitialization() {
        return weightInitialization;
    }

    /**
     * Sets the number of data points processed per mini batch.
     *
     * @param batchSize the batch size, at least one
     * @throws ArithmeticException if the value is not positive
     */
    public void setBatchSize(int batchSize) {
        // A zero batch size would stall the training loop and a negative one
        // cannot be used to size the per-batch buffers, so reject both up front.
        if (batchSize < 1) {
            throw new ArithmeticException("Batch size must be positive, not " + batchSize);
        }
        this.batchSize = batchSize;
    }

    /**
     * @return the mini batch size
     */
    public int getBatchSize() {
        return batchSize;
    }

    /**
     * Sets the activation function applied at every layer of the network.
     *
     * @param f the activation function, must not be null
     * @throws NullPointerException if the argument is null
     */
    public void setActivationFunction(ActivationFunction f) {
        // Reject null here rather than failing during training or prediction.
        if (f == null) {
            throw new NullPointerException("Activation function must not be null");
        }
        this.f = f;
    }

    /**
     * @return the activation function in use
     */
    public ActivationFunction getActivationFunction() {
        return f;
    }

    /**
     * Classifies a data point. The network outputs are shifted so that the
     * "not this class" training target maps to zero, negative values are
     * clipped to zero, and the result is normalized.
     *
     * @param data the data point to classify
     * @return the normalized per class scores
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults result = new CategoricalResults(outputSize);
        Vec outputs = feedForward(data.getNumericalValues());
        double negativeTarget = f.min() + targetBump;
        outputs.mutableSubtract(negativeTarget);
        final int count = outputs.length();
        for (int c = 0; c < count; ++c) {
            double shifted = outputs.get(c);
            result.setProb(c, Math.max(shifted, 0.0));
        }
        result.normalize();
        return result;
    }

    /**
     * Predicts a regression value by mapping the single network output from
     * the activation range back to the target range seen during training.
     *
     * @param data the data point to predict for
     * @return the predicted target value
     */
    @Override
    public double regress(DataPoint data) {
        double raw = feedForward(data.getNumericalValues()).get(0);
        // Inverse of the target scaling applied while training.
        double shifted = raw - f.min() - targetBump;
        return shifted / targetMultiplier + targetMin;
    }

    /**
     * Trains on a classification data set. The {@code parallel} flag is
     * ignored; this simply delegates to {@link #train(ClassificationDataSet)}.
     *
     * @param dataSet the training data
     * @param parallel ignored
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    /**
     * Trains the network as a classifier with one output neuron per class.
     * Any previously learned weights are discarded and re-initialized.
     *
     * @param dataSet the training data
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        inputSize = dataSet.getNumNumericalVars();
        outputSize = dataSet.getClassSize();
        setUp(RandomUtil.getRandom());
        trainNN(dataSet);
    }

    /**
     * Trains on a regression data set. The {@code parallel} flag is ignored;
     * this simply delegates to {@link #train(RegressionDataSet)}.
     *
     * @param dataSet the training data
     * @param parallel ignored
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    /**
     * Trains the network as a regressor with a single output neuron. The
     * targets are linearly rescaled so that the observed minimum and maximum
     * map to the activation function's bounds pulled inwards by the target
     * bump. Any previously learned weights are discarded and re-initialized.
     * <p>
     * When the targets span no positive range (all values equal, or the data
     * set is empty) a scale factor of 1 is used, so that the scaling stays
     * finite instead of dividing by zero and producing NaN targets.
     *
     * @param dataSet the training data
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        double max = Double.NEGATIVE_INFINITY;
        double min = Double.POSITIVE_INFINITY;
        for (int i = 0; i < dataSet.size(); ++i) {
            double val = dataSet.getTargetValue(i);
            max = Math.max(max, val);
            min = Math.min(min, val);
        }
        this.targetMax = max;
        this.targetMin = min;

        double usableRange = this.f.max() - this.targetBump - (this.f.min() + this.targetBump);
        double targetRange = max - min;
        // A zero (or otherwise non-positive / NaN) range cannot be used as a divisor.
        this.targetMultiplier = targetRange > 0.0 ? usableRange / targetRange : 1.0;

        this.inputSize = dataSet.getNumNumericalVars();
        this.outputSize = 1;
        Random rand = RandomUtil.getRandom();
        this.setUp(rand);
        this.trainNN(dataSet);
    }

    /**
     * @return {@code false}, always
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Creates a copy of this network through the copy constructor.
     *
     * @return a new network with the same settings and cloned weights
     */
    @Override
    public BackPropagationNet clone() {
        BackPropagationNet copy = new BackPropagationNet(this);
        return copy;
    }

    /**
     * Allocates and randomly initializes the weight matrix and bias vector of
     * every layer: one per hidden layer followed by the output layer. Each
     * matrix has one row per neuron of its layer and one column per value fed
     * into it.
     *
     * @param rand the source of randomness for the initial weights
     */
    private void setUp(Random rand) {
        final int hiddenLayers = npl.length;
        Ws = new ArrayList<Matrix>(hiddenLayers);
        bs = new ArrayList<Vec>(hiddenLayers);

        int fanIn = inputSize;
        // Index hiddenLayers denotes the output layer.
        for (int layer = 0; layer <= hiddenLayers; ++layer) {
            int fanOut = layer < hiddenLayers ? npl[layer] : outputSize;
            DenseMatrix weights = new DenseMatrix(fanOut, fanIn);
            DenseVector bias = new DenseVector(weights.rows());
            initializeWeights(weights, rand);
            initializeWeights(bias, weights.cols(), rand);
            Ws.add(weights);
            bs.add(bias);
            fanIn = fanOut;
        }
    }

    /**
     * Computes the output layer error signal for one data point and stores it
     * in {@code delta_out}.
     * <p>
     * For classification the target is the activation maximum minus the bump
     * for the true class and the activation minimum plus the bump for every
     * other class. For regression the target value is rescaled into the
     * activation range.
     *
     * @param dataSet the data set being trained on
     * @param idx the index of the data point within the data set
     * @param delta_out receives the output layer deltas
     * @param a_i the output layer activations for the data point
     * @param d_i the output layer activation derivatives for the data point
     * @return the sum of squared differences between targets and activations
     * @throws RuntimeException if the data set is neither a classification nor
     * a regression data set
     */
    private double computeOutputDelta(DataSet dataSet, int idx, Vec delta_out, Vec a_i, Vec d_i) {
        double squaredError = 0.0;
        if (dataSet instanceof ClassificationDataSet) {
            int trueClass = ((ClassificationDataSet)dataSet).getDataPointCategory(idx);
            // First write the targets into delta_out ...
            for (int c = 0; c < outputSize; ++c) {
                double target = c == trueClass ? f.max() - targetBump : f.min() + targetBump;
                delta_out.set(c, target);
            }
            // ... then replace each target by its error signal.
            for (int j = 0; j < delta_out.length(); ++j) {
                double diff = delta_out.get(j) - a_i.get(j);
                squaredError += Math.pow(diff, 2.0);
                delta_out.set(j, -diff * d_i.get(j));
            }
            return squaredError;
        }
        if (dataSet instanceof RegressionDataSet) {
            double rawTarget = ((RegressionDataSet)dataSet).getTargetValue(idx);
            double target = f.min() + targetBump + targetMultiplier * (rawTarget - targetMin);
            squaredError += Math.pow(target - a_i.get(0), 2.0);
            delta_out.set(0, -(target - a_i.get(0)) * d_i.get(0));
            return squaredError;
        }
        throw new RuntimeException("BUG: please report");
    }

    /**
     * Feeds an input through the network, recording every layer's activation
     * and activation derivative in the supplied, pre-allocated vectors.
     *
     * @param input the input vector
     * @param activations one vector per layer, overwritten with that layer's activation
     * @param derivatives one vector per layer, overwritten with the derivative
     * function applied to that layer's activation
     */
    private void feedForward(Vec input, List<Vec> activations, List<Vec> derivatives) {
        final int layers = Ws.size();
        Vec layerInput = input;
        for (int layer = 0; layer < layers; ++layer) {
            Vec activation = activations.get(layer);
            // multiply accumulates into its target, so start from zero
            activation.zeroOut();
            Ws.get(layer).multiply(layerInput, 1.0, activation);
            activation.mutableAdd(bs.get(layer));
            activation.applyFunction(f);

            // The derivative function is evaluated on the activated value.
            Vec derivative = derivatives.get(layer);
            activation.copyTo(derivative);
            derivative.applyFunction(f.getD());

            layerInput = activation;
        }
    }

    /**
     * Feeds an input through the network and returns the output layer's
     * activation. The input vector itself is not modified.
     *
     * @param input the input vector
     * @return the activation of the output layer
     */
    private Vec feedForward(Vec input) {
        Vec current = input;
        for (int layer = 0; layer < Ws.size(); ++layer) {
            Vec next = Ws.get(layer).multiply(current);
            next.mutableAdd(bs.get(layer));
            next.applyFunction(f);
            current = next;
        }
        return current;
    }

    /**
     * Fills a weight matrix with values drawn from the configured
     * initialization strategy, visiting the entries row by row. The column
     * count is passed as the input size and the row count as the layer size.
     *
     * @param W the matrix to fill
     * @param rand the source of randomness
     */
    private void initializeWeights(Matrix W, Random rand) {
        final int rows = W.rows();
        final int cols = W.cols();
        for (int r = 0; r < rows; ++r) {
            for (int c = 0; c < cols; ++c) {
                double w = weightInitialization.getWeight(cols, rows, initialLearningRate, rand);
                W.set(r, c, w);
            }
        }
    }

    /**
     * Fills a bias vector with values drawn from the configured
     * initialization strategy, using the vector length as the layer size.
     *
     * @param b the bias vector to fill
     * @param inputSize the number of inputs feeding the layer
     * @param rand the source of randomness
     */
    private void initializeWeights(Vec b, int inputSize, Random rand) {
        final int layerSize = b.length();
        for (int k = 0; k < layerSize; ++k) {
            double bias = weightInitialization.getWeight(inputSize, layerSize, initialLearningRate, rand);
            b.set(k, bias);
        }
    }

    /**
     * Lists the tunable parameters of this network: those discovered from the
     * getter / setter methods, one integer parameter per hidden layer size,
     * and the choice of activation function.
     *
     * @return an unmodifiable list of parameters
     * @throws ArithmeticException if a hidden layer has a non positive size
     */
    @Override
    public List<Parameter> getParameters() {
        List<Parameter> params = new ArrayList<Parameter>(Parameter.getParamsFromMethods(this));
        for (int i = 0; i < this.npl.length; ++i) {
            // Captured by the anonymous class below, so it must not change.
            final int layer = i;
            if (this.npl[layer] < 1) {
                throw new ArithmeticException("There must be a poistive number of hidden neurons in each layer");
            }
            params.add(new IntParameter(){
                private static final long serialVersionUID = -827784019950722754L;

                @Override
                public int getValue() {
                    return BackPropagationNet.this.npl[layer];
                }

                @Override
                public boolean setValue(int val) {
                    if (val <= 0) {
                        return false;
                    }
                    BackPropagationNet.this.npl[layer] = val;
                    return true;
                }

                @Override
                public String getASCIIName() {
                    return "Neurons for Hidden Layer " + layer;
                }
            });
        }
        params.add(new ObjectParameter<ActivationFunction>(){
            private static final long serialVersionUID = 6871130865935243583L;

            @Override
            public ActivationFunction getObject() {
                return BackPropagationNet.this.getActivationFunction();
            }

            @Override
            public boolean setObject(ActivationFunction obj) {
                BackPropagationNet.this.setActivationFunction(obj);
                return true;
            }

            @Override
            public List<ActivationFunction> parameterOptions() {
                return Arrays.asList(logitActiv, tanhActiv, softsignActiv);
            }

            @Override
            public String getASCIIName() {
                return "Activation Function";
            }
        });
        return Collections.unmodifiableList(params);
    }

    /**
     * Looks up one of this network's parameters by name.
     *
     * @param paramName the name of the parameter
     * @return the result of looking the name up in the parameter map built
     * from {@link #getParameters()}
     */
    @Override
    public Parameter getParameter(String paramName) {
        List<Parameter> all = getParameters();
        return Parameter.toParameterMap(all).get(paramName);
    }

    /**
     * A neuron activation function together with its bounds and derivative.
     * Evaluating the instance as a {@link Function1D} yields the response.
     */
    public static abstract class ActivationFunction
    implements Function1D,
    Serializable {
        private static final long serialVersionUID = 8002040194215453918L;

        /**
         * Computes the activation for a raw (pre-activation) value.
         *
         * @param x the raw value
         * @return the activated value
         */
        public abstract double response(double x);

        /**
         * @return the lower bound of the response
         */
        public abstract double min();

        /**
         * @return the upper bound of the response
         */
        public abstract double max();

        /**
         * Returns the derivative of the activation. Within this file the
         * returned function is applied to already activated values, so
         * implementations express the derivative in terms of the response.
         *
         * @return the derivative function
         */
        public abstract Function1D getD();

        @Override
        public double f(double x) {
            return response(x);
        }
    }

    /**
     * Strategies for drawing the initial value of a single weight.
     */
    public static enum WeightInitialization {
        /** Uniform draw from [-0.7, 0.7), independent of the layer dimensions. */
        UNIFORM{

            @Override
            public double getWeight(int inputSize, int layerSize, double eta, Random rand) {
                double u = rand.nextDouble();
                return u * 1.4 - 0.7;
            }
        }
        ,
        /**
         * Draw obtained by passing a uniform value through the normal inverse
         * CDF with location 0 and a spread argument of {@code inputSize^-0.5}
         * (presumably the standard deviation; confirm against Normal.invcdf).
         */
        GUASSIAN{

            @Override
            public double getWeight(int inputSize, int layerSize, double eta, Random rand) {
                double p = rand.nextDouble();
                double spread = Math.pow(inputSize, -0.5);
                return Normal.invcdf(p, 0.0, spread);
            }
        }
        ,
        /**
         * Uniform draw from [-c, c) where
         * {@code c = sqrt(6 / (inputSize + layerSize))}.
         */
        TANH_NORMALIZED_INITIALIZATION{

            @Override
            public double getWeight(int inputSize, int layerSize, double eta, Random rand) {
                double bound = Math.sqrt(6.0 / (double)(inputSize + layerSize));
                double u = rand.nextDouble();
                return u * bound * 2.0 - bound;
            }
        };


        /**
         * Draws one initial weight.
         *
         * @param inputSize the number of inputs feeding the layer
         * @param layerSize the number of neurons in the layer
         * @param eta the initial learning rate (not used by the built in strategies)
         * @param rand the source of randomness
         * @return the initial weight value
         */
        public abstract double getWeight(int inputSize, int layerSize, double eta, Random rand);
    }
}

