/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.neuralnetwork;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.neuralnetwork.SGDNetworkTrainer;
import jsat.classifiers.neuralnetwork.activations.ActivationLayer;
import jsat.classifiers.neuralnetwork.activations.ReLU;
import jsat.classifiers.neuralnetwork.activations.SoftmaxLayer;
import jsat.classifiers.neuralnetwork.initializers.ConstantInit;
import jsat.classifiers.neuralnetwork.initializers.GaussianNormalInit;
import jsat.classifiers.neuralnetwork.regularizers.Max2NormRegularizer;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.math.optimization.stochastic.AdaDelta;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.concurrent.ParallelUtils;

/**
 * A simple fully connected feed-forward neural network classifier.
 * <p>
 * Network structure:
 * <ul>
 * <li>Every hidden layer uses a {@link ReLU} activation.</li>
 * <li>The output layer uses a {@link SoftmaxLayer} with one neuron per class.</li>
 * <li>Only the numeric features of each data point are used as input.</li>
 * </ul>
 * <p>
 * Training is delegated to an {@link SGDNetworkTrainer} configured with:
 * <ul>
 * <li>the {@link AdaDelta} gradient updater with {@code eta = 1.0};</li>
 * <li>a {@code Max2NormRegularizer(25.0)};</li>
 * <li>{@code GaussianNormalInit(0.01)} weight initialization;</li>
 * <li>{@code ConstantInit(0.1)} bias initialization.</li>
 * </ul>
 * <p>
 * By default the network has two hidden layers of 1024 neurons each. It is
 * trained for 100 epochs using mini-batches of 256 data points.
 */
public class DReDNetSimple implements Classifier, Parameterized {
    private static final long serialVersionUID = -342281027279571332L;

    /** The trained network, or {@code null} until {@link #train} has been called. */
    private SGDNetworkTrainer network;
    /** Number of neurons in each hidden layer, ordered from input side to output side. */
    private int[] hiddenSizes;
    /** Number of data points used for each mini-batch update. */
    private int batchSize = 256;
    /** Number of full passes over the training data. */
    private int epochs = 100;

    /**
     * Creates a new network with two hidden layers of 1024 neurons each.
     */
    public DReDNetSimple() {
        this(1024, 1024);
    }

    /**
     * Creates a new network with the given hidden layer sizes.
     *
     * @param hiddenLayerSizes the number of neurons in each hidden layer
     * @throws IllegalArgumentException if any layer size is not positive
     */
    public DReDNetSimple(int... hiddenLayerSizes) {
        setHiddenSizes(hiddenLayerSizes);
    }

    /**
     * Sets the number of neurons in each hidden layer. The given array is
     * copied, so later changes to it do not affect this classifier.
     *
     * @param hiddenSizes the number of neurons in each hidden layer
     * @throws IllegalArgumentException if any layer size is not positive
     */
    public void setHiddenSizes(int[] hiddenSizes) {
        for (int i = 0; i < hiddenSizes.length; i++) {
            if (hiddenSizes[i] <= 0) {
                throw new IllegalArgumentException("Hidden layer " + i + " must contain a positive number of neurons, not " + hiddenSizes[i]);
            }
        }
        this.hiddenSizes = Arrays.copyOf(hiddenSizes, hiddenSizes.length);
    }

    /**
     * Returns the number of neurons in each hidden layer.
     *
     * @return a copy of the hidden layer sizes; modifying it does not affect
     *         this classifier
     */
    public int[] getHiddenSizes() {
        // Defensive copy, mirroring the copy made in setHiddenSizes.
        return Arrays.copyOf(hiddenSizes, hiddenSizes.length);
    }

    /**
     * Sets the number of data points used for each mini-batch update.
     *
     * @param batchSize the mini-batch size
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public void setBatchSize(int batchSize) {
        // A non-positive size would make the mini-batch loop in train never advance.
        if (batchSize <= 0) {
            throw new IllegalArgumentException("Batch size must be positive, not " + batchSize);
        }
        this.batchSize = batchSize;
    }

    /**
     * Returns the mini-batch size.
     *
     * @return the number of data points used for each mini-batch update
     */
    public int getBatchSize() {
        return batchSize;
    }

    /**
     * Sets the number of full passes over the training data.
     *
     * @param epochs the number of training epochs
     * @throws IllegalArgumentException if {@code epochs} is not positive
     */
    public void setEpochs(int epochs) {
        if (epochs <= 0) {
            throw new IllegalArgumentException("Number of epochs must be positive");
        }
        this.epochs = epochs;
    }

    /**
     * Returns the number of training epochs.
     *
     * @return the number of full passes over the training data
     */
    public int getEpochs() {
        return epochs;
    }

    /**
     * Classifies a data point using its numeric features. The result holds the
     * output of the softmax output layer.
     *
     * @param data the data point to classify
     * @return the per-class results produced by the network
     * @throws IllegalStateException if the classifier has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (network == null) {
            throw new IllegalStateException("Classifier has not been trained");
        }
        Vec output = network.feedfoward(data.getNumericalValues());
        return new CategoricalResults(output.arrayCopy());
    }

    /**
     * Trains a new network on the given data set, discarding any previously
     * trained network.
     * <p>
     * Each epoch visits the data points in a fresh random order, split into
     * consecutive mini-batches of {@link #getBatchSize()} points.
     *
     * @param dataSet  the training data
     * @param parallel whether mini-batch updates may use a multi-threaded executor
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        setup(dataSet);
        List<Vec> X = dataSet.getDataVectors();
        List<Vec> Y = oneHotTargets(dataSet);

        IntList randOrder = new IntList(X.size());
        ListUtils.addRange(randOrder, 0, X.size(), 1);

        List<Vec> xBatch = new ArrayList<>(batchSize);
        List<Vec> yBatch = new ArrayList<>(batchSize);
        // Only allocate worker threads when they will actually be used.
        ExecutorService threadPool = parallel ? ParallelUtils.getNewExecutor(true) : null;
        try {
            for (int epoch = 0; epoch < epochs; epoch++) {
                Collections.shuffle(randOrder);
                for (int start = 0; start < X.size(); start += batchSize) {
                    int end = Math.min(start + batchSize, X.size());
                    xBatch.clear();
                    yBatch.clear();
                    for (int j = start; j < end; j++) {
                        // Go through the shuffled permutation so each epoch sees different batches.
                        int idx = randOrder.get(j);
                        xBatch.add(X.get(idx));
                        yBatch.add(Y.get(idx));
                    }
                    if (threadPool != null) {
                        network.updateMiniBatch(xBatch, yBatch, threadPool);
                    } else {
                        network.updateMiniBatch(xBatch, yBatch);
                    }
                }
            }
            network.finishUpdating();
        } finally {
            // Release the worker threads even if training fails.
            if (threadPool != null) {
                threadPool.shutdown();
            }
        }
    }

    /**
     * Builds one target vector per data point. Each vector has length equal to
     * the number of classes and holds 1.0 at the point's class index.
     */
    private static List<Vec> oneHotTargets(ClassificationDataSet dataSet) {
        List<Vec> targets = new ArrayList<>(dataSet.size());
        for (int i = 0; i < dataSet.size(); i++) {
            SparseVector target = new SparseVector(dataSet.getClassSize(), 1);
            target.set(dataSet.getDataPointCategory(i), 1.0);
            targets.add(target);
        }
        return targets;
    }

    /**
     * Creates and configures a fresh network sized for the given data set.
     * <p>
     * Layer sizes are: the number of numeric inputs, then the hidden layers,
     * then the number of classes.
     */
    private void setup(ClassificationDataSet dataSet) {
        network = new SGDNetworkTrainer();

        int[] sizes = new int[hiddenSizes.length + 2];
        sizes[0] = dataSet.getNumNumericalVars();
        System.arraycopy(hiddenSizes, 0, sizes, 1, hiddenSizes.length);
        sizes[sizes.length - 1] = dataSet.getClassSize();
        network.setLayerSizes(sizes);

        // One ReLU per hidden layer, followed by the softmax output layer.
        List<ActivationLayer> activations = new ArrayList<>(hiddenSizes.length + 1);
        for (int i = 0; i < hiddenSizes.length; i++) {
            activations.add(new ReLU());
        }
        activations.add(new SoftmaxLayer());
        network.setLayersActivation(activations);

        network.setRegularizer(new Max2NormRegularizer(25.0));
        network.setWeightInit(new GaussianNormalInit(0.01));
        network.setBiasInit(new ConstantInit(0.1));
        network.setEta(1.0);
        network.setGradientUpdater(new AdaDelta());
        network.setup();
    }

    /**
     * This classifier ignores data point weights.
     *
     * @return {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Creates a copy of this classifier. The copy has the same configuration
     * and, if this classifier has been trained, a clone of the trained network.
     *
     * @return a copy of this classifier
     */
    @Override
    public DReDNetSimple clone() {
        DReDNetSimple copy = new DReDNetSimple(hiddenSizes);
        if (network != null) {
            copy.network = network.clone();
        }
        copy.batchSize = batchSize;
        copy.epochs = epochs;
        return copy;
    }
}

