/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.neuralnetwork;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.neuralnetwork.activations.ActivationLayer;
import jsat.classifiers.neuralnetwork.initializers.BiastInitializer;
import jsat.classifiers.neuralnetwork.initializers.WeightInitializer;
import jsat.classifiers.neuralnetwork.regularizers.Max2NormRegularizer;
import jsat.classifiers.neuralnetwork.regularizers.WeightRegularizer;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.math.decayrates.DecayRate;
import jsat.math.decayrates.NoDecay;
import jsat.math.optimization.stochastic.GradientUpdater;
import jsat.math.optimization.stochastic.SimpleSGD;
import jsat.utils.SystemInfo;
import jsat.utils.random.RandomUtil;

/**
 * Mini-batch stochastic gradient descent trainer for a fully connected
 * feed-forward network.
 * <p>
 * Typical use: configure the layer sizes, activations, initializers and
 * (optionally) the updater / regularizer / dropout rates, call
 * {@link #setup()}, feed batches through
 * {@link #updateMiniBatch(List, List)}, and finally call
 * {@link #finishUpdating()} before using {@link #feedfoward(Vec)} for
 * prediction.
 * <p>
 * Layer {@code l} (0-based) owns a weight matrix of shape
 * {@code layerSizes[l + 1] x layerSizes[l]} and a bias vector of length
 * {@code layerSizes[l + 1]}. Mini-batches are stored column-wise: one column
 * per data point.
 */
public class SGDNetworkTrainer
implements Serializable {
    private static final long serialVersionUID = 5753653181230693131L;
    private static final Logger LOG = Logger.getLogger(SGDNetworkTrainer.class.getName());

    /** Neuron count of every layer, input layer first. */
    private int[] layerSizes;
    /** Base learning rate handed to {@link #etaDecay}. */
    private double eta;
    /** Probability of zeroing an input value during training. */
    private double p_i;
    /** {@link #p_i} mapped onto the {@code int} range, see {@link #dropoutThreshold(double)}. */
    private int p_i_intThresh;
    /** Probability of zeroing a hidden (non-output) pre-activation during training. */
    private double p_o;
    /** {@link #p_o} mapped onto the {@code int} range, see {@link #dropoutThreshold(double)}. */
    private int p_o_intThresh;
    /** Prototype that is cloned once per weight row and once per bias vector. */
    private GradientUpdater updater = new SimpleSGD();
    private WeightRegularizer regularizer = new Max2NormRegularizer(15.0);
    private WeightInitializer weightInit;
    private BiastInitializer biasInit;
    /** Weight matrix of each layer. */
    private List<Matrix> W;
    /** Accumulated weight gradient of each layer for the current batch. */
    private List<Matrix> W_deltas;
    /** One updater per row of each layer's weight matrix. */
    private List<List<GradientUpdater>> W_updaters;
    /** Bias vector of each layer. */
    private List<Vec> B;
    /** Accumulated bias gradient of each layer for the current batch. */
    private List<Vec> B_deltas;
    /** One updater per layer's bias vector. */
    private List<GradientUpdater> B_updaters;
    private List<ActivationLayer> layersActivation;
    private DecayRate etaDecay = new NoDecay();
    /** Number of mini-batch updates performed since {@link #setup()}. */
    private int time;
    /** Per-layer scratch space, lazily (re)allocated to match the batch size. */
    private Matrix[] activations;
    private Matrix[] unactivated;
    private Matrix[] deltas;

    /** Work performed on a single matrix row by {@link #forEachRow}. */
    private interface RowAction {
        void apply(int row);
    }

    /**
     * Creates a trainer with an input dropout probability of 0.2 and a hidden
     * dropout probability of 0.5.
     */
    public SGDNetworkTrainer() {
        // Assign directly instead of going through the overridable setters.
        this.p_i = 0.2;
        this.p_i_intThresh = SGDNetworkTrainer.dropoutThreshold(this.p_i);
        this.p_o = 0.5;
        this.p_o_intThresh = SGDNetworkTrainer.dropoutThreshold(this.p_o);
    }

    /**
     * Copy constructor. Weights, biases, gradient buffers, updaters and
     * activation layers are deep-copied; configuration values that have not
     * been set on {@code toCopy} stay unset on the copy.
     *
     * @param toCopy the trainer to copy
     */
    public SGDNetworkTrainer(SGDNetworkTrainer toCopy) {
        if (toCopy.layerSizes != null) {
            this.layerSizes = Arrays.copyOf(toCopy.layerSizes, toCopy.layerSizes.length);
        }
        this.eta = toCopy.eta;
        // The decay schedule and its position must survive the copy, otherwise a
        // cloned trainer would silently restart with NoDecay at time 0.
        // NOTE(review): the DecayRate reference is shared because nothing in this
        // file shows a clone() contract for it - confirm and deep-copy if available.
        this.etaDecay = toCopy.etaDecay;
        this.time = toCopy.time;
        this.weightInit = toCopy.weightInit == null ? null : toCopy.weightInit.clone();
        this.biasInit = toCopy.biasInit == null ? null : toCopy.biasInit.clone();
        this.regularizer = toCopy.regularizer == null ? null : toCopy.regularizer.clone();
        this.updater = toCopy.updater == null ? null : toCopy.updater.clone();
        this.p_i = toCopy.p_i;
        this.p_i_intThresh = toCopy.p_i_intThresh;
        this.p_o = toCopy.p_o;
        this.p_o_intThresh = toCopy.p_o_intThresh;
        if (toCopy.W != null) {
            this.W = new ArrayList<Matrix>(toCopy.W.size());
            for (Matrix w : toCopy.W) {
                this.W.add(w.clone());
            }
            this.B = new ArrayList<Vec>(toCopy.B.size());
            for (Vec b : toCopy.B) {
                this.B.add(b.clone());
            }
        }
        if (toCopy.W_deltas != null) {
            this.W_deltas = new ArrayList<Matrix>(toCopy.W_deltas.size());
            for (Matrix wDelta : toCopy.W_deltas) {
                this.W_deltas.add(wDelta.clone());
            }
            this.B_deltas = new ArrayList<Vec>(toCopy.B_deltas.size());
            for (Vec bDelta : toCopy.B_deltas) {
                this.B_deltas.add(bDelta.clone());
            }
        }
        if (toCopy.W_updaters != null) {
            this.W_updaters = new ArrayList<List<GradientUpdater>>(toCopy.W_updaters.size());
            for (List<GradientUpdater> rowUpdaters : toCopy.W_updaters) {
                List<GradientUpdater> copyUpdaters = new ArrayList<GradientUpdater>(rowUpdaters.size());
                for (GradientUpdater item : rowUpdaters) {
                    copyUpdaters.add(item.clone());
                }
                this.W_updaters.add(copyUpdaters);
            }
            // Size hint only: seeding the list with toCopy's elements would alias
            // the source's updaters and double the list length.
            this.B_updaters = new ArrayList<GradientUpdater>(toCopy.B_updaters.size());
            for (GradientUpdater biasUpdater : toCopy.B_updaters) {
                this.B_updaters.add(biasUpdater.clone());
            }
        }
        // Scratch buffers hold no state across batches; fresh empty slots are
        // enough because updateMiniBatch allocates them on demand.
        if (toCopy.activations != null) {
            this.activations = new Matrix[toCopy.activations.length];
        }
        if (toCopy.unactivated != null) {
            this.unactivated = new Matrix[toCopy.unactivated.length];
        }
        if (toCopy.deltas != null) {
            this.deltas = new Matrix[toCopy.deltas.length];
        }
        if (toCopy.layersActivation != null) {
            this.layersActivation = new ArrayList<ActivationLayer>(toCopy.layersActivation.size());
            for (ActivationLayer activationLayer : toCopy.layersActivation) {
                this.layersActivation.add(activationLayer.clone());
            }
        }
    }

    /**
     * Validates a dropout probability and maps it onto the {@code int} range so
     * that {@code rand.nextInt() < threshold} holds with (approximately) that
     * probability for a uniformly distributed {@code nextInt()}.
     *
     * @param p the dropout probability
     * @return the integer threshold corresponding to {@code p}
     * @throws IllegalArgumentException if {@code p} is NaN or outside [0, 1)
     */
    private static int dropoutThreshold(double p) {
        if (p < 0.0 || p >= 1.0 || Double.isNaN(p)) {
            throw new IllegalArgumentException("Dropout probability must be in [0,1) not " + p);
        }
        // (2^32 - 1) * p shifted so that p = 0 lands on Integer.MIN_VALUE.
        return (int)(4.294967295E9 * p + -2.147483648E9);
    }

    /**
     * Sets the probability of dropping an input value during training.
     *
     * @param p the dropout probability in [0, 1)
     * @throws IllegalArgumentException if {@code p} is NaN or outside [0, 1)
     */
    public void setDropoutInput(double p) {
        int threshold = SGDNetworkTrainer.dropoutThreshold(p);
        this.p_i = p;
        this.p_i_intThresh = threshold;
    }

    /**
     * @return the probability of dropping an input value during training
     */
    public double getDropoutInput() {
        return this.p_i;
    }

    /**
     * Sets the probability of dropping a hidden-layer value during training.
     *
     * @param p the dropout probability in [0, 1)
     * @throws IllegalArgumentException if {@code p} is NaN or outside [0, 1)
     */
    public void setDropoutHidden(double p) {
        int threshold = SGDNetworkTrainer.dropoutThreshold(p);
        this.p_o = p;
        this.p_o_intThresh = threshold;
    }

    /**
     * @return the probability of dropping a hidden-layer value during training
     */
    public double getDropoutHidden() {
        return this.p_o;
    }

    /**
     * @param etaDecay the schedule used to derive each batch's learning rate
     *                 from the base rate
     */
    public void setEtaDecay(DecayRate etaDecay) {
        this.etaDecay = etaDecay;
    }

    /**
     * @return the learning rate decay schedule
     */
    public DecayRate getEtaDecay() {
        return this.etaDecay;
    }

    /**
     * Sets the base learning rate.
     *
     * @param eta a positive, finite learning rate
     * @throws IllegalArgumentException if {@code eta} is not positive and finite
     */
    public void setEta(double eta) {
        if (eta <= 0.0 || Double.isNaN(eta) || Double.isInfinite(eta)) {
            throw new IllegalArgumentException("eta must be a positive constant, not " + eta);
        }
        this.eta = eta;
    }

    /**
     * @return the base learning rate
     */
    public double getEta() {
        return this.eta;
    }

    /**
     * @param regularizer the regularizer applied after every gradient step
     */
    public void setRegularizer(WeightRegularizer regularizer) {
        this.regularizer = regularizer;
    }

    /**
     * @return the regularizer applied after every gradient step
     */
    public WeightRegularizer getRegularizer() {
        return this.regularizer;
    }

    /**
     * Sets the neuron count of every layer, input layer first. The array is
     * stored as given, not copied.
     *
     * @param layerSizes the size of each layer
     */
    public void setLayerSizes(int ... layerSizes) {
        this.layerSizes = layerSizes;
    }

    /**
     * @return the internal layer size array (not a copy)
     */
    public int[] getLayerSizes() {
        return this.layerSizes;
    }

    /**
     * Sets the activation of every non-input layer; {@link #setup()} asserts
     * that there is exactly one entry fewer than there are layer sizes.
     *
     * @param layersActivation the activation layers, first hidden layer first
     */
    public void setLayersActivation(List<ActivationLayer> layersActivation) {
        this.layersActivation = layersActivation;
    }

    /**
     * @param updater the prototype updater cloned for every weight row and
     *                bias vector during {@link #setup()}
     */
    public void setGradientUpdater(GradientUpdater updater) {
        this.updater = updater;
    }

    /**
     * @return the prototype gradient updater
     */
    public GradientUpdater getGradientUpdater() {
        return this.updater;
    }

    /**
     * @param weightInit the initializer used for every weight matrix
     */
    public void setWeightInit(WeightInitializer weightInit) {
        this.weightInit = weightInit;
    }

    /**
     * @return the weight initializer
     */
    public WeightInitializer getWeightInit() {
        return this.weightInit;
    }

    /**
     * @param biasInit the initializer used for every bias vector
     */
    public void setBiasInit(BiastInitializer biasInit) {
        this.biasInit = biasInit;
    }

    /**
     * @return the bias initializer
     */
    public BiastInitializer getBiasInit() {
        return this.biasInit;
    }

    /**
     * Allocates and initializes all weights and biases, resets the update
     * counter and prepares the gradient buffers and updaters. Layer sizes,
     * activations and both initializers must have been set beforehand.
     */
    public void setup() {
        assert (this.layersActivation.size() == this.layerSizes.length - 1);
        int layerCount = this.layersActivation.size();
        this.W = new ArrayList<Matrix>(layerCount);
        this.B = new ArrayList<Vec>(layerCount);
        Random rand = RandomUtil.getRandom();
        for (int l = 1; l < this.layerSizes.length; ++l) {
            int fanIn = this.layerSizes[l - 1];
            int fanOut = this.layerSizes[l];
            Matrix W_l = new DenseMatrix(fanOut, fanIn);
            this.W.add(W_l);
            this.weightInit.init(W_l, rand);
            Vec B_l = new DenseVector(fanOut);
            this.B.add(B_l);
            this.biasInit.init(B_l, fanIn, rand);
        }
        this.time = 0;
        this.prepareForUpdating();
    }

    /**
     * Allocates the gradient accumulators, one updater per weight row and per
     * bias vector, and the (initially empty) per-layer scratch arrays.
     */
    private void prepareForUpdating() {
        int layerCount = this.layersActivation.size();
        this.W_deltas = new ArrayList<Matrix>(layerCount);
        this.W_updaters = new ArrayList<List<GradientUpdater>>(layerCount);
        this.B_deltas = new ArrayList<Vec>(layerCount);
        this.B_updaters = new ArrayList<GradientUpdater>(layerCount);
        for (int l = 1; l < this.layerSizes.length; ++l) {
            int fanIn = this.layerSizes[l - 1];
            int fanOut = this.layerSizes[l];
            this.W_deltas.add(new DenseMatrix(fanOut, fanIn));
            this.B_deltas.add(new DenseVector(fanOut));
            List<GradientUpdater> W_updaters_l = new ArrayList<GradientUpdater>(fanOut);
            for (int i = 0; i < fanOut; ++i) {
                GradientUpdater W_updater = this.updater.clone();
                W_updater.setup(fanIn);
                W_updaters_l.add(W_updater);
            }
            this.W_updaters.add(W_updaters_l);
            GradientUpdater B_updater = this.updater.clone();
            B_updater.setup(fanOut);
            this.B_updaters.add(B_updater);
        }
        this.activations = new Matrix[layerCount];
        this.unactivated = new Matrix[layerCount];
        this.deltas = new Matrix[layerCount];
    }

    /**
     * Releases all training-only state and rescales the parameters for
     * prediction: the first layer's weights and biases are multiplied by
     * {@code 1 - dropoutInput}, every later layer's by
     * {@code 1 - dropoutHidden}.
     * <p>
     * The rescaling is unconditional, so calling this method twice scales
     * twice. After this call {@link #updateMiniBatch(List, List)} can no longer
     * be used until {@link #setup()} is called again.
     */
    public void finishUpdating() {
        this.W_deltas = null;
        this.W_updaters = null;
        this.B_deltas = null;
        this.B_updaters = null;
        this.deltas = null;
        this.unactivated = null;
        this.activations = null;
        double keepInput = 1.0 - this.p_i;
        this.W.get(0).mutableMultiply(keepInput);
        this.B.get(0).mutableMultiply(keepInput);
        double keepHidden = 1.0 - this.p_o;
        for (int i = 1; i < this.W.size(); ++i) {
            this.W.get(i).mutableMultiply(keepHidden);
            this.B.get(i).mutableMultiply(keepHidden);
        }
    }

    /**
     * Performs one single-threaded mini-batch update.
     *
     * @param x the batch inputs
     * @param y the target outputs, {@code y.get(i)} belonging to {@code x.get(i)}
     * @return the summed error of the batch, see
     *         {@link #updateMiniBatch(List, List, ExecutorService)}
     */
    public double updateMiniBatch(List<Vec> x, List<Vec> y) {
        return this.updateMiniBatch(x, y, null);
    }

    /**
     * Performs one mini-batch update: forward pass with dropout, error
     * back-propagation, gradient averaging over the batch and one gradient
     * step with the current decayed learning rate.
     *
     * @param x  the batch inputs
     * @param y  the target outputs, {@code y.get(i)} belonging to {@code x.get(i)}
     * @param ex the executor used for parallel work, or {@code null} to run
     *           everything on the calling thread
     * @return the sum over the batch of {@code pNorm(2.0)} of
     *         (network output - target)
     */
    public double updateMiniBatch(List<Vec> x, List<Vec> y, ExecutorService ex) {
        Random rand = RandomUtil.getRandom();
        for (Matrix w : this.W_deltas) {
            w.zeroOut();
        }
        for (Vec b : this.B_deltas) {
            b.zeroOut();
        }
        int batchSize = x.size();
        // Scratch matrices are reused between batches of identical size.
        for (int i = 0; i < this.layersActivation.size(); ++i) {
            int rows = this.layerSizes[i + 1];
            if (this.activations[i] == null || this.activations[i].cols() != batchSize) {
                this.activations[i] = new DenseMatrix(rows, batchSize);
            }
            if (this.unactivated[i] == null || this.unactivated[i].cols() != batchSize) {
                this.unactivated[i] = new DenseMatrix(rows, batchSize);
            }
            if (this.deltas[i] == null || this.deltas[i].cols() != batchSize) {
                this.deltas[i] = new DenseMatrix(rows, batchSize);
            }
        }
        // One column per data point.
        DenseMatrix X = new DenseMatrix(this.layerSizes[0], batchSize);
        for (int j = 0; j < batchSize; ++j) {
            x.get(j).copyTo(X.getColumnView(j));
        }
        if (this.p_i > 0.0) {
            SGDNetworkTrainer.applyDropout(X, this.p_i_intThresh, rand, ex);
        }
        this.feedforward(X, this.activations, this.unactivated, ex, rand);
        double errorMade = this.backpropagateError(this.deltas, this.activations, x, y, 0.0, ex, this.unactivated);
        this.accumulateUpdates(X, this.activations, this.deltas, ex, x);
        double eta_cur = this.etaDecay.rate(this.time++, this.eta);
        if (ex == null) {
            this.applyGradient(eta_cur);
        } else {
            this.applyGradient(eta_cur, ex);
        }
        return errorMade;
    }

    /**
     * Training-time forward pass over a whole batch. For every layer computes
     * {@code z = W * a_prev + b}, applies dropout to {@code z} for all but the
     * last layer, and stores the activation of {@code z}.
     *
     * @param X            the (already input-dropped) batch, one column per point
     * @param activationsM receives the activation of every layer
     * @param unactivatedM receives the pre-activation of every layer
     * @param ex           executor for parallel work, or {@code null}
     * @param rand         source of randomness for hidden dropout
     */
    private void feedforward(Matrix X, Matrix[] activationsM, Matrix[] unactivatedM, ExecutorService ex, Random rand) {
        int lastLayer = this.layersActivation.size() - 1;
        for (int l = 0; l <= lastLayer; ++l) {
            Matrix a_lprev = l == 0 ? X : activationsM[l - 1];
            Matrix a_l = activationsM[l];
            final Matrix z_l = unactivatedM[l];
            z_l.zeroOut();
            if (ex == null) {
                this.W.get(l).multiply(a_lprev, z_l);
            } else {
                this.W.get(l).multiply(a_lprev, z_l, ex);
            }
            final Vec B_l = this.B.get(l);
            // Add the layer's bias to every column of z.
            SGDNetworkTrainer.forEachRow(z_l.rows(), ex, new RowAction(){

                @Override
                public void apply(int row) {
                    double B_li = B_l.get(row);
                    for (int j = 0; j < z_l.cols(); ++j) {
                        z_l.increment(row, j, B_li);
                    }
                }
            });
            if (this.p_o > 0.0 && l != lastLayer) {
                SGDNetworkTrainer.applyDropout(z_l, this.p_o_intThresh, rand, ex);
            }
            this.layersActivation.get(l).activate(z_l, a_l, false);
        }
    }

    /**
     * Feeds a single data point through the network without dropout.
     *
     * @param x the input vector
     * @return the activation of the final layer
     */
    public Vec feedfoward(Vec x) {
        Vec a_lprev = x;
        for (int l = 0; l < this.layersActivation.size(); ++l) {
            DenseVector z_l = new DenseVector(this.layerSizes[l + 1]);
            z_l.zeroOut();
            this.W.get(l).multiply(a_lprev, 1.0, z_l);
            z_l.mutableAdd(this.B.get(l));
            // Activate in place; the result feeds the next layer.
            this.layersActivation.get(l).activate(z_l, z_l);
            a_lprev = z_l;
        }
        return a_lprev;
    }

    /**
     * Computes the error term of every layer, last layer first. The output
     * layer's term is {@code activation - target}; every earlier layer's term
     * is {@code W[l+1]^T * delta[l+1]} passed through that layer's
     * {@code backprop}.
     *
     * @param deltasM      receives the error term of every layer
     * @param activationsM the activations from the forward pass
     * @param x            the batch inputs (only the size is used)
     * @param y            the batch targets
     * @param errorMade    the starting value of the error accumulator
     * @param ex           executor for parallel work, or {@code null}
     * @param unactivatedM the pre-activations from the forward pass
     * @return {@code errorMade} plus the summed {@code pNorm(2.0)} of the
     *         output layer's error columns
     */
    private double backpropagateError(Matrix[] deltasM, Matrix[] activationsM, List<Vec> x, List<Vec> y, double errorMade, ExecutorService ex, Matrix[] unactivatedM) {
        int lastLayer = this.layersActivation.size() - 1;
        for (int l = lastLayer; l >= 0; --l) {
            Matrix delta_l = deltasM[l];
            if (l == lastLayer) {
                activationsM[l].copyTo(delta_l);
                for (int r = 0; r < x.size(); ++r) {
                    Vec column = delta_l.getColumnView(r);
                    column.mutableSubtract(y.get(r));
                    errorMade += delta_l.getColumnView(r).pNorm(2.0);
                }
                continue;
            }
            delta_l.zeroOut();
            if (ex == null) {
                this.W.get(l + 1).transposeMultiply(deltasM[l + 1], delta_l);
            } else {
                this.W.get(l + 1).transposeMultiply(deltasM[l + 1], delta_l, ex);
            }
            this.layersActivation.get(l).backprop(unactivatedM[l], activationsM[l], delta_l, delta_l, false);
        }
        return errorMade;
    }

    /**
     * Turns the per-layer error terms into batch-averaged gradients: the
     * weight gradient is {@code delta * a_prev^T / batchSize}, the bias
     * gradient is the row mean of {@code delta}.
     *
     * @param X            the batch inputs as a matrix
     * @param activationsM the activations from the forward pass
     * @param deltasM      the error terms from back-propagation
     * @param ex           executor for parallel work, or {@code null}
     * @param x            the batch inputs (only the size is used)
     */
    private void accumulateUpdates(Matrix X, Matrix[] activationsM, Matrix[] deltasM, ExecutorService ex, List<Vec> x) {
        final double invXsize = 1.0 / (double)x.size();
        for (int l = 0; l < this.layersActivation.size(); ++l) {
            Matrix a_lprev = l == 0 ? X : activationsM[l - 1];
            final Matrix delta_l = deltasM[l];
            Matrix W_delta_l = this.W_deltas.get(l);
            if (ex == null) {
                delta_l.multiplyTranspose(a_lprev, W_delta_l);
            } else {
                delta_l.multiplyTranspose(a_lprev, W_delta_l, ex);
            }
            W_delta_l.mutableMultiply(invXsize);
            final Vec B_delta_l = this.B_deltas.get(l);
            SGDNetworkTrainer.forEachRow(delta_l.rows(), ex, new RowAction(){

                @Override
                public void apply(int row) {
                    double change = 0.0;
                    for (int j = 0; j < delta_l.cols(); ++j) {
                        change += delta_l.get(row, j);
                    }
                    B_delta_l.increment(row, change * invXsize);
                }
            });
        }
    }

    /**
     * Applies the accumulated gradients on the calling thread and then
     * regularizes every layer as a whole.
     *
     * @param eta_cur the learning rate for this step
     */
    private void applyGradient(double eta_cur) {
        for (int l = 0; l < this.layersActivation.size(); ++l) {
            Vec B_l = this.B.get(l);
            this.B_updaters.get(l).update(B_l, this.B_deltas.get(l), eta_cur);
            Matrix W_l = this.W.get(l);
            Matrix W_dl = this.W_deltas.get(l);
            List<GradientUpdater> updaters_l = this.W_updaters.get(l);
            for (int i = 0; i < W_l.rows(); ++i) {
                updaters_l.get(i).update(W_l.getRowView(i), W_dl.getRowView(i), eta_cur);
            }
            this.regularizer.applyRegularization(W_l, B_l);
        }
    }

    /**
     * Applies the accumulated gradients with one task per weight row; each
     * task also regularizes its row together with the matching bias entry.
     * Returns once every task has finished, or early if the calling thread is
     * interrupted (the interrupt flag is then restored).
     *
     * @param eta_cur the learning rate for this step
     * @param ex      the executor running the row tasks
     */
    private void applyGradient(final double eta_cur, ExecutorService ex) {
        List<Future<?>> futures = new ArrayList<Future<?>>();
        for (int l = 0; l < this.layersActivation.size(); ++l) {
            final Vec B_l = this.B.get(l);
            this.B_updaters.get(l).update(B_l, this.B_deltas.get(l), eta_cur);
            final Matrix W_l = this.W.get(l);
            final Matrix W_dl = this.W_deltas.get(l);
            final List<GradientUpdater> updaters_l = this.W_updaters.get(l);
            for (int row = 0; row < W_l.rows(); ++row) {
                final int i = row;
                futures.add(ex.submit(new Runnable(){

                    @Override
                    public void run() {
                        Vec W_li = W_l.getRowView(i);
                        updaters_l.get(i).update(W_li, W_dl.getRowView(i), eta_cur);
                        B_l.set(i, SGDNetworkTrainer.this.regularizer.applyRegularizationToRow(W_li, B_l.get(i)));
                    }
                }));
            }
        }
        for (Future<?> future : futures) {
            try {
                future.get();
            }
            catch (ExecutionException taskFailure) {
                // Report the failed row but keep waiting for the remaining rows so
                // no task is still writing to the weights when this method returns.
                LOG.log(Level.SEVERE, null, taskFailure);
            }
            catch (InterruptedException interrupted) {
                Thread.currentThread().interrupt();
                LOG.log(Level.SEVERE, null, interrupted);
                break;
            }
        }
    }

    /**
     * Zeroes every entry of {@code X} for which {@code rand.nextInt()} falls
     * below {@code randThresh}.
     *
     * @param X          the matrix to modify in place
     * @param randThresh the threshold from {@link #dropoutThreshold(double)}
     * @param rand       the source of randomness; shared by all worker tasks
     *                   when {@code ex} is not {@code null}
     * @param ex         executor for parallel work, or {@code null}
     */
    private static void applyDropout(final Matrix X, final int randThresh, final Random rand, ExecutorService ex) {
        SGDNetworkTrainer.forEachRow(X.rows(), ex, new RowAction(){

            @Override
            public void apply(int row) {
                for (int j = 0; j < X.cols(); ++j) {
                    if (rand.nextInt() < randThresh) {
                        X.set(row, j, 0.0);
                    }
                }
            }
        });
    }

    /**
     * Runs {@code action} once for every row index in {@code [0, rows)}.
     * Without an executor the rows are visited in ascending order on the
     * calling thread. With an executor the rows are striped over at most
     * {@code SystemInfo.LogicalCores} tasks and the call blocks until all of
     * them have finished, or until the calling thread is interrupted (the
     * interrupt flag is then restored).
     *
     * @param rows   the number of rows to visit
     * @param ex     the executor to use, or {@code null} for sequential work
     * @param action the work to perform for each row
     */
    private static void forEachRow(final int rows, ExecutorService ex, final RowAction action) {
        if (ex == null) {
            for (int i = 0; i < rows; ++i) {
                action.apply(i);
            }
            return;
        }
        final int stride = SystemInfo.LogicalCores;
        // The latch count must equal the number of submitted tasks; otherwise
        // idle tasks could release the latch while real work is still running.
        final int taskCount = Math.min(stride, rows);
        final CountDownLatch latch = new CountDownLatch(taskCount);
        for (int id = 0; id < taskCount; ++id) {
            final int firstRow = id;
            ex.submit(new Runnable(){

                @Override
                public void run() {
                    try {
                        for (int i = firstRow; i < rows; i += stride) {
                            action.apply(i);
                        }
                    }
                    finally {
                        // Always release the waiter, even if the action threw.
                        latch.countDown();
                    }
                }
            });
        }
        try {
            latch.await();
        }
        catch (InterruptedException ex1) {
            Thread.currentThread().interrupt();
            LOG.log(Level.SEVERE, null, ex1);
        }
    }

    /**
     * @return a copy of this trainer made with the copy constructor
     */
    @Override
    protected SGDNetworkTrainer clone() {
        return new SGDNetworkTrainer(this);
    }
}

