/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.DoubleAdder;
import jsat.DataSet;
import jsat.SimpleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.WarmClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.ConcatenatedVec;
import jsat.linear.DenseVector;
import jsat.linear.SubVector;
import jsat.linear.Vec;
import jsat.lossfunctions.LossC;
import jsat.lossfunctions.LossFunc;
import jsat.lossfunctions.LossMC;
import jsat.lossfunctions.LossR;
import jsat.lossfunctions.SoftmaxLoss;
import jsat.math.Function;
import jsat.math.FunctionVec;
import jsat.math.optimization.LBFGS;
import jsat.math.optimization.Optimizer;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.utils.concurrent.ParallelUtils;

public class LinearBatch
implements Classifier,
Regressor,
Parameterized,
SimpleWeightVectorModel,
WarmClassifier,
WarmRegressor {
    private static final long serialVersionUID = -446156124954287580L;
    private Vec[] ws;
    private double[] bs;
    private LossFunc loss;
    private double lambda0;
    private Optimizer optimizer;
    private double tolerance;
    private boolean useBiasTerm = true;

    /**
     * Creates a model that uses a {@link SoftmaxLoss} and a regularization
     * constant of 1e-6. The remaining settings come from the constructor this
     * one delegates to.
     */
    public LinearBatch() {
        this(new SoftmaxLoss(), 1e-6);
    }

    /**
     * Creates a model with the given loss and regularization constant and a
     * convergence tolerance of 0.001.
     *
     * @param loss the loss function to minimize
     * @param lambda0 the regularization constant, validated by
     * {@link #setLambda0(double)}
     */
    public LinearBatch(LossFunc loss, double lambda0) {
        this(loss, lambda0, 0.001);
    }

    /**
     * Creates a model with the given loss, regularization constant and
     * tolerance, passing {@code null} as the optimizer.
     *
     * @param loss the loss function to minimize
     * @param lambda0 the regularization constant
     * @param tolerance the convergence tolerance handed to the optimizer
     */
    public LinearBatch(LossFunc loss, double lambda0, double tolerance) {
        this(loss, lambda0, tolerance, null);
    }

    /**
     * Creates a fully specified model. Every argument is routed through its
     * setter, so the setters' validation applies here as well.
     *
     * @param loss the loss function to minimize
     * @param lambda0 the regularization constant
     * @param tolerance the convergence tolerance handed to the optimizer
     * @param optimizer the optimizer to use, may be {@code null}
     * @throws IllegalArgumentException if {@code lambda0} or {@code tolerance}
     * is rejected by its setter
     */
    public LinearBatch(LossFunc loss, double lambda0, double tolerance, Optimizer optimizer) {
        this.setLoss(loss);
        this.setLambda0(lambda0);
        this.setOptimizer(optimizer);
        this.setTolerance(tolerance);
    }

    /**
     * Copy constructor. The loss, the optimizer (when present), the weight
     * vectors and the bias array are all duplicated, so the new instance
     * shares no mutable state with {@code toCopy}.
     *
     * @param toCopy the model to copy
     */
    public LinearBatch(LinearBatch toCopy) {
        this(toCopy.loss.clone(), toCopy.lambda0, toCopy.tolerance, toCopy.optimizer == null ? null : toCopy.optimizer.clone());
        // This flag was previously not carried over, so a copy of a model
        // configured without a bias term silently reverted to the default.
        this.useBiasTerm = toCopy.useBiasTerm;
        if (toCopy.ws != null) {
            this.ws = new Vec[toCopy.ws.length];
            for (int i = 0; i < toCopy.ws.length; ++i) {
                this.ws[i] = toCopy.ws[i].clone();
            }
        }
        if (toCopy.bs != null) {
            this.bs = Arrays.copyOf(toCopy.bs, toCopy.bs.length);
        }
    }

    /**
     * Sets whether training also fits the bias term(s) in addition to the
     * weight vectors. The field defaults to {@code true}.
     *
     * @param useBiasTerm {@code true} to fit bias terms
     */
    public void setUseBiasTerm(boolean useBiasTerm) {
        this.useBiasTerm = useBiasTerm;
    }

    /**
     * @return {@code true} if bias terms are fitted during training
     */
    public boolean isUseBiasTerm() {
        return useBiasTerm;
    }

    /**
     * Sets the regularization constant.
     *
     * @param lambda0 a finite, non-negative value
     * @throws IllegalArgumentException if {@code lambda0} is negative, NaN or
     * infinite
     */
    public void setLambda0(double lambda0) {
        // isFinite rejects both NaN and the two infinities in one check
        final boolean valid = Double.isFinite(lambda0) && !(lambda0 < 0.0);
        if (!valid) {
            throw new IllegalArgumentException("Lambda0 must be non-negative, not " + lambda0);
        }
        this.lambda0 = lambda0;
    }

    /**
     * @return the regularization constant
     */
    public double getLambda0() {
        return lambda0;
    }

    /**
     * Sets the loss function to minimize. No validation happens here; the
     * {@code train} methods check that the loss supports the task being fit.
     *
     * @param loss the loss function
     */
    public void setLoss(LossFunc loss) {
        this.loss = loss;
    }

    /**
     * @return the loss function currently configured
     */
    public LossFunc getLoss() {
        return loss;
    }

    /**
     * Sets the optimizer used for training. When {@code null}, the
     * {@code train} methods fall back to a new {@code LBFGS(10)}; otherwise
     * they work on a clone of the given optimizer.
     *
     * @param optimizer the optimizer to use, or {@code null} for the default
     */
    public void setOptimizer(Optimizer optimizer) {
        this.optimizer = optimizer;
    }

    /**
     * @return the configured optimizer, or {@code null} if none was set
     */
    public Optimizer getOptimizer() {
        return optimizer;
    }

    /**
     * Sets the convergence tolerance passed to the optimizer.
     *
     * @param tolerance a finite, non-negative value
     * @throws IllegalArgumentException if {@code tolerance} is negative, NaN
     * or infinite
     */
    public void setTolerance(double tolerance) {
        // isFinite rejects both NaN and the two infinities in one check
        final boolean valid = Double.isFinite(tolerance) && !(tolerance < 0.0);
        if (!valid) {
            throw new IllegalArgumentException("Tolerance must be a non-negative constant, not " + tolerance);
        }
        this.tolerance = tolerance;
    }

    /**
     * @return the convergence tolerance passed to the optimizer
     */
    public double getTolerance() {
        return tolerance;
    }

    /**
     * Classifies a data point from its numeric features. With a single weight
     * vector the raw score {@code w.x + b} is handed to the loss (cast to
     * {@link LossC}); with several, one score per weight vector is computed,
     * processed in place by the loss (cast to {@link LossMC}) and converted
     * to a result.
     *
     * @param data the point to classify
     * @return the classification produced by the loss function
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        final Vec features = data.getNumericalValues();
        final int numModels = ws.length;
        if (numModels == 1) {
            LossC binaryLoss = (LossC) loss;
            return binaryLoss.getClassification(ws[0].dot(features) + bs[0]);
        }
        Vec scores = new DenseVector(numModels);
        for (int k = 0; k < numModels; k++) {
            scores.set(k, ws[k].dot(features) + bs[k]);
        }
        LossMC multiLoss = (LossMC) loss;
        multiLoss.process(scores, scores);
        return multiLoss.getClassification(scores);
    }

    /**
     * Predicts a regression value from the point's numeric features by
     * passing the raw score {@code w.x + b} to the loss (cast to
     * {@link LossR}).
     *
     * @param data the point to predict for
     * @return the regression value produced by the loss function
     */
    @Override
    public double regress(DataPoint data) {
        final Vec features = data.getNumericalValues();
        LossR regressionLoss = (LossR) loss;
        return regressionLoss.getRegression(ws[0].dot(features) + bs[0]);
    }

    /**
     * Trains a classifier from scratch, i.e. without a warm-start solution.
     *
     * @param D2 the training data
     * @param parallel whether the work may be done in parallel
     */
    @Override
    public void train(ClassificationDataSet D2, boolean parallel) {
        train(D2, null, parallel);
    }

    /**
     * Fits the model to a classification data set. One weight vector (and
     * bias) is used when there are at most two classes; otherwise one per
     * class, which requires a {@link LossMC} loss. The weights start at zero
     * unless a warm solution is supplied.
     *
     * @param D2 the training data, which must have numeric features
     * @param warmSolution a previous solution to start from, or {@code null}
     * @param parallel whether the work may be done in parallel
     * @throws FailedToFitException if the data has no numeric features, the
     * loss does not support the task, or the warm solution is unusable
     */
    @Override
    public void train(ClassificationDataSet D2, Classifier warmSolution, boolean parallel) {
        final int numFeatures = D2.getNumNumericalVars();
        if (numFeatures <= 0) {
            throw new FailedToFitException("LinearBatch requires numeric features to work");
        }
        if (!(this.loss instanceof LossC)) {
            throw new FailedToFitException("Loss function " + this.loss.getClass().getSimpleName() + " does not support classification");
        }
        final int numClasses = D2.getClassSize();
        final int numModels;
        if (numClasses > 2) {
            if (!(this.loss instanceof LossMC)) {
                throw new FailedToFitException("Loss function " + this.loss.getClass().getSimpleName() + " does not support multi-class classification");
            }
            numModels = numClasses;
        } else {
            numModels = 1;
        }
        this.ws = new Vec[numModels];
        this.bs = new double[numModels];
        for (int i = 0; i < numModels; ++i) {
            this.ws[i] = new DenseVector(numFeatures);
        }
        // Work on a private optimizer so concurrent or repeated calls do not share state
        Optimizer optimizerToUse = this.optimizer == null ? new LBFGS(10) : this.optimizer.clone();
        this.doWarmStartIfNotNull(warmSolution);
        ExecutorService threadPool = ParallelUtils.getNewExecutor(parallel);
        try {
            if (numModels == 1) {
                if (this.useBiasTerm) {
                    // VecWithBias lets the optimizer update ws[0] and bs[0] as one vector
                    VecWithBias w_tmp = new VecWithBias(this.ws[0], this.bs);
                    optimizerToUse.optimize(this.tolerance, w_tmp, w_tmp, new LossFunction(D2, this.loss), new GradFunction(D2, this.loss), parallel);
                } else {
                    optimizerToUse.optimize(this.tolerance, this.ws[0], this.ws[0], new LossFunction(D2, this.loss), new GradFunction(D2, this.loss), parallel);
                }
            } else {
                LossMC lossMC = (LossMC)this.loss;
                ConcatenatedVec wAll;
                if (this.useBiasTerm) {
                    // Layout: all weight vectors back to back, then the biases at the end
                    ArrayList<Vec> vecs = new ArrayList<Vec>(Arrays.asList(this.ws));
                    vecs.add(DenseVector.toDenseVec(this.bs));
                    wAll = new ConcatenatedVec(vecs);
                } else {
                    wAll = new ConcatenatedVec(Arrays.asList(this.ws));
                }
                optimizerToUse.optimize(this.tolerance, wAll, new DenseVector(wAll), new LossMCFunction(D2, lossMC), new GradMCFunction(D2, lossMC), parallel);
            }
        } finally {
            // Release the executor even when the optimizer or a loss callback throws
            threadPool.shutdownNow();
        }
    }

    /**
     * Copies the weights (and, when bias terms are in use, the biases) of a
     * previous solution into the already allocated {@code ws} / {@code bs}.
     * Does nothing when {@code warmSolution} is {@code null}.
     *
     * @param warmSolution the solution to start from, or {@code null}
     * @throws FailedToFitException if the solution is not a
     * {@link SimpleWeightVectorModel} or has a different number of weight
     * vectors than this model
     */
    private void doWarmStartIfNotNull(Object warmSolution) throws FailedToFitException {
        if (warmSolution == null) {
            return;
        }
        if (!(warmSolution instanceof SimpleWeightVectorModel)) {
            throw new FailedToFitException("Can not warm start from " + warmSolution.getClass().getCanonicalName());
        }
        SimpleWeightVectorModel warm = (SimpleWeightVectorModel)warmSolution;
        if (warm.numWeightsVecs() != this.ws.length) {
            throw new FailedToFitException("Warm solution has " + warm.numWeightsVecs() + " weight vectors instead of " + this.ws.length);
        }
        for (int i = 0; i < this.ws.length; ++i) {
            warm.getRawWeight(i).copyTo(this.ws[i]);
            if (this.useBiasTerm) {
                this.bs[i] = warm.getBias(i);
            }
        }
    }

    /**
     * Trains a regressor from scratch, i.e. without a warm-start solution.
     *
     * @param D2 the training data
     * @param parallel whether the work may be done in parallel
     */
    @Override
    public void train(RegressionDataSet D2, boolean parallel) {
        // Passing "this" as the warm solution only copied freshly zeroed
        // weights onto themselves; null states the intent directly and
        // mirrors the classification overload.
        this.train(D2, (Regressor) null, parallel);
    }

    /**
     * Trains a regressor from a warm-start solution without parallelism.
     *
     * @param dataSet the training data
     * @param warmSolution the solution to start from
     */
    @Override
    public void train(RegressionDataSet dataSet, Regressor warmSolution) {
        train(dataSet, warmSolution, false);
    }

    /**
     * Fits the model to a regression data set using a single weight vector
     * and bias. The weights start at zero unless a warm solution is supplied.
     *
     * @param D2 the training data, which must have numeric features
     * @param warmSolution a previous solution to start from, or {@code null}
     * @param parallel whether the work may be done in parallel
     * @throws FailedToFitException if the data has no numeric features, the
     * loss is not a {@link LossR}, or the warm solution is unusable
     */
    @Override
    public void train(RegressionDataSet D2, Regressor warmSolution, boolean parallel) {
        final int numFeatures = D2.getNumNumericalVars();
        if (numFeatures <= 0) {
            throw new FailedToFitException("LinearBatch requires numeric features to work");
        }
        if (!(this.loss instanceof LossR)) {
            throw new FailedToFitException("Loss function " + this.loss.getClass().getSimpleName() + " does not support regression");
        }
        this.ws = new Vec[]{new DenseVector(numFeatures)};
        this.bs = new double[1];
        // Work on a private optimizer so concurrent or repeated calls do not share state
        Optimizer optimizerToUse = this.optimizer == null ? new LBFGS(10) : this.optimizer.clone();
        this.doWarmStartIfNotNull(warmSolution);
        ExecutorService threadPool = ParallelUtils.getNewExecutor(parallel);
        try {
            if (this.useBiasTerm) {
                // VecWithBias lets the optimizer update ws[0] and bs[0] as one vector
                VecWithBias w_tmp = new VecWithBias(this.ws[0], this.bs);
                optimizerToUse.optimize(this.tolerance, w_tmp, w_tmp, new LossFunction(D2, this.loss), new GradFunction(D2, this.loss), parallel);
            } else {
                optimizerToUse.optimize(this.tolerance, this.ws[0], this.ws[0], new LossFunction(D2, this.loss), new GradFunction(D2, this.loss), parallel);
            }
        } finally {
            // Release the executor even when the optimizer or a loss callback throws
            threadPool.shutdownNow();
        }
    }

    /**
     * Returns the training target of the {@code i}-th point. For a
     * classification set the category index {@code c} is mapped to
     * {@code 2c - 1}; for any other data set the value is read as a
     * regression target.
     *
     * @param D2 the data set, either classification or regression
     * @param i the index of the data point
     * @return the numeric target used by the loss function
     */
    private static double getTargetY(DataSet D2, int i) {
        if (D2 instanceof ClassificationDataSet) {
            int category = ((ClassificationDataSet) D2).getDataPointCategory(i);
            return category * 2 - 1;
        }
        return ((RegressionDataSet) D2).getTargetValue(i);
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean warmFromSameDataOnly() {
        return false;
    }

    /**
     * Returns the stored weight vector itself, not a copy.
     *
     * @param index which weight vector to return
     * @return the weight vector at {@code index}
     */
    @Override
    public Vec getRawWeight(int index) {
        return ws[index];
    }

    /**
     * @param index which bias term to return
     * @return the bias paired with the weight vector at {@code index}
     */
    @Override
    public double getBias(int index) {
        return bs[index];
    }

    /**
     * @return the number of weight vectors currently held by the model
     */
    @Override
    public int numWeightsVecs() {
        return ws.length;
    }

    /**
     * @return always {@code true}; the loss and gradient functions of this
     * class scale each point's contribution by its data set weight
     */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * @return a copy of this model created through the copy constructor
     */
    @Override
    public LinearBatch clone() {
        return new LinearBatch(this);
    }

    /**
     * Suggests a search distribution for the regularization constant. The
     * data set argument is currently not inspected.
     *
     * @param d the data set the model would be trained on (unused)
     * @return a {@link LogUniform} distribution over [1e-7, 1e-2]
     */
    public static Distribution guessLambda0(DataSet d) {
        final double low = 1e-7;
        final double high = 1e-2;
        return new LogUniform(low, high);
    }

    /**
     * Gradient of the multi-class training objective: the weighted mean of
     * the per-point loss plus {@code lambda0 * w.w}.
     * <p>
     * The parameter vector {@code w} is read as one block of
     * {@code subWSize} weights per class, followed (when bias terms are in
     * use) by one bias per class at the very end.
     */
    private class GradMCFunction
    implements FunctionVec {
        private final ClassificationDataSet D;
        private final LossMC loss;

        public GradMCFunction(ClassificationDataSet D2, LossMC loss) {
            this.D = D2;
            this.loss = loss;
        }

        /**
         * Evaluates the gradient at {@code w}.
         *
         * @param w the current parameter vector
         * @param s storage for the result, or {@code null} to allocate one
         * @param parllel whether the data may be processed in parallel
         * @return the gradient, stored in {@code s} when it was non-null
         */
        @Override
        public Vec f(Vec w, Vec s, boolean parllel) {
            final Vec grad = s == null ? w.clone() : s;
            grad.zeroOut();
            final boolean withBias = LinearBatch.this.useBiasTerm;
            final int numClasses = this.D.getClassSize();
            final int numBias = withBias ? LinearBatch.this.bs.length : 0;
            final int biasStart = w.length() - numBias;
            final int subWSize = biasStart / numClasses;
            final DoubleAdder weightSum = new DoubleAdder();
            ParallelUtils.run(parllel, this.D.size(), (start, end) -> {
                // One zeroed accumulator per chunk. A thread-local accumulator could be
                // handed to the reducer twice if one pooled thread ran two chunks.
                Vec s_l = grad.clone();
                Vec pred_local = new DenseVector(numClasses);
                for (int i = start; i < end; ++i) {
                    Vec x = this.D.getDataPoint(i).getNumericalValues();
                    double weight = this.D.getWeight(i);
                    for (int k = 0; k < numClasses; ++k) {
                        pred_local.set(k, new SubVector(k * subWSize, subWSize, w).dot(x));
                    }
                    if (withBias) {
                        pred_local.mutableAdd(new SubVector(biasStart, numBias, w));
                    }
                    this.loss.process(pred_local, pred_local);
                    int y = this.D.getDataPointCategory(i);
                    // pred_local now holds the derivative of the loss w.r.t. each class score
                    this.loss.deriv(pred_local, pred_local, y);
                    for (int k = 0; k < numClasses; ++k) {
                        new SubVector(k * subWSize, subWSize, s_l).mutableAdd(pred_local.get(k) * weight, x);
                    }
                    if (withBias) {
                        // Each score depends on its bias with slope 1, so the bias
                        // gradient is the score derivative itself.
                        new SubVector(biasStart, numBias, s_l).mutableAdd(weight, pred_local);
                    }
                    weightSum.add(weight);
                }
                return s_l;
            }, (a, b) -> a.add((Vec)b)).copyTo(grad);
            grad.mutableDivide(weightSum.sum());
            if (LinearBatch.this.lambda0 > 0.0) {
                // d/dw of lambda0 * w.w is +2 * lambda0 * w
                grad.mutableAdd(2.0 * LinearBatch.this.lambda0, w);
            }
            return grad;
        }
    }

    /**
     * Multi-class training objective: the weighted mean of the per-point
     * loss, plus {@code lambda0 * w.w} when {@code lambda0 > 0}.
     * <p>
     * The parameter vector {@code w} is read as one block of
     * {@code subWSize} weights per class, followed (when bias terms are in
     * use) by one bias per class at the very end.
     */
    public class LossMCFunction
    implements Function {
        private static final long serialVersionUID = -861700500356609563L;
        private final ClassificationDataSet D;
        private final LossMC loss;

        public LossMCFunction(ClassificationDataSet D2, LossMC loss) {
            this.D = D2;
            this.loss = loss;
        }

        /**
         * Evaluates the objective at {@code w}.
         *
         * @param w the current parameter vector
         * @param parallel whether the data may be processed in parallel
         * @return the regularized, weight-normalized loss
         */
        @Override
        public double f(Vec w, boolean parallel) {
            final boolean withBias = LinearBatch.this.useBiasTerm;
            final int numClasses = this.D.getClassSize();
            final int numBias = withBias ? LinearBatch.this.bs.length : 0;
            final int biasStart = w.length() - numBias;
            final int subWSize = biasStart / numClasses;
            final DoubleAdder sum = new DoubleAdder();
            final DoubleAdder weightSum = new DoubleAdder();
            ParallelUtils.run(parallel, this.D.size(), (start, end) -> {
                Vec pred_local = new DenseVector(numClasses);
                for (int i = start; i < end; ++i) {
                    Vec x = this.D.getDataPoint(i).getNumericalValues();
                    double weight = this.D.getWeight(i);
                    for (int k = 0; k < numClasses; ++k) {
                        pred_local.set(k, new SubVector(k * subWSize, subWSize, w).dot(x));
                    }
                    if (withBias) {
                        pred_local.mutableAdd(new SubVector(biasStart, numBias, w));
                    }
                    this.loss.process(pred_local, pred_local);
                    int y = this.D.getDataPointCategory(i);
                    sum.add(this.loss.getLoss(pred_local, y) * weight);
                    weightSum.add(weight);
                }
            });
            // Normalize in every case. The unregularized branch used to return the raw
            // sum, which disagreed with the gradient (always divided by the weight sum).
            double objective = sum.sum() / weightSum.sum();
            if (LinearBatch.this.lambda0 > 0.0) {
                objective += LinearBatch.this.lambda0 * w.dot(w);
            }
            return objective;
        }
    }

    /**
     * Gradient of the single-output training objective: the weighted mean of
     * the per-point loss plus {@code lambda0 * w.w}. Targets are obtained
     * through {@code getTargetY}, so the same class serves binary
     * classification and regression.
     */
    public class GradFunction
    implements FunctionVec {
        private final DataSet D;
        private final LossFunc loss;

        public GradFunction(DataSet D2, LossFunc loss) {
            this.D = D2;
            this.loss = loss;
        }

        /**
         * Evaluates the gradient at {@code w}.
         *
         * @param w the current parameter vector
         * @param s storage for the result, or {@code null} to allocate one
         * @param parallel whether the data may be processed in parallel
         * @return the gradient, stored in {@code s} when it was non-null
         */
        @Override
        public Vec f(Vec w, Vec s, boolean parallel) {
            final Vec grad = s == null ? w.clone() : s;
            grad.zeroOut();
            final DoubleAdder weightSum = new DoubleAdder();
            ParallelUtils.run(parallel, this.D.size(), (start, end) -> {
                // One zeroed accumulator per chunk. A thread-local accumulator could be
                // handed to the reducer twice if one pooled thread ran two chunks.
                Vec s_l = grad.clone();
                for (int i = start; i < end; ++i) {
                    Vec x = this.D.getDataPoint(i).getNumericalValues();
                    double y = LinearBatch.getTargetY(this.D, i);
                    double weight = this.D.getWeight(i);
                    s_l.mutableAdd(this.loss.getDeriv(w.dot(x), y) * weight, x);
                    weightSum.add(weight);
                }
                return s_l;
            }, (a, b) -> a.add((Vec)b)).copyTo(grad);
            grad.mutableDivide(weightSum.sum());
            if (LinearBatch.this.lambda0 > 0.0) {
                // d/dw of lambda0 * w.w is +2 * lambda0 * w
                grad.mutableAdd(2.0 * LinearBatch.this.lambda0, w);
            }
            return grad;
        }
    }

    /**
     * Single-output training objective: the weighted mean of the per-point
     * loss, plus {@code lambda0 * w.w} when {@code lambda0 > 0}. Targets are
     * obtained through {@code getTargetY}.
     */
    public class LossFunction
    implements Function {
        private static final long serialVersionUID = -576682206943283356L;
        private final DataSet D;
        private final LossFunc loss;

        public LossFunction(DataSet D2, LossFunc loss) {
            this.D = D2;
            this.loss = loss;
        }

        /**
         * Evaluates the objective at {@code w}.
         *
         * @param w the current parameter vector
         * @param parallel whether the data may be processed in parallel
         * @return the weight-normalized loss, regularized when lambda0 is positive
         */
        @Override
        public double f(Vec w, boolean parallel) {
            final DoubleAdder lossTotal = new DoubleAdder();
            final DoubleAdder weightTotal = new DoubleAdder();
            ParallelUtils.run(parallel, D.size(), (from, to) -> {
                for (int idx = from; idx < to; idx++) {
                    Vec features = D.getDataPoint(idx).getNumericalValues();
                    double target = getTargetY(D, idx);
                    lossTotal.add(loss.getLoss(w.dot(features), target) * D.getWeight(idx));
                    weightTotal.add(D.getWeight(idx));
                }
            });
            final double meanLoss = lossTotal.sum() / weightTotal.sum();
            if (lambda0 > 0.0) {
                return meanLoss + lambda0 * w.dot(w);
            }
            return meanLoss;
        }
    }

    /**
     * A view that presents a weight vector and a one-element bias array as a
     * single vector of length {@code w.length() + 1}, with the bias in the
     * last position. Writes go straight through to the wrapped objects.
     * <p>
     * {@code dot} and {@code mutableAdd} special-case arguments whose length
     * equals {@code w.length()}: such an argument is handled as if it had a
     * trailing 1 appended.
     */
    private class VecWithBias
    extends Vec {
        public Vec w;
        public double[] b;

        public VecWithBias(Vec w, double[] b) {
            this.w = w;
            this.b = b;
        }

        @Override
        public double dot(Vec v) {
            if (v.length() != w.length()) {
                return super.dot(v);
            }
            // Shorter argument: implicit trailing 1 pairs with the bias
            return w.dot(v) + b[0];
        }

        @Override
        public void mutableAdd(double c, Vec other) {
            if (other.length() != w.length()) {
                super.mutableAdd(c, other);
                return;
            }
            // Shorter argument: implicit trailing 1 contributes c to the bias
            w.mutableAdd(c, other);
            b[0] = b[0] + c;
        }

        @Override
        public int length() {
            return 1 + w.length();
        }

        @Override
        public double get(int index) {
            final int biasIndex = w.length();
            if (index < biasIndex) {
                return w.get(index);
            }
            if (index > biasIndex) {
                throw new IndexOutOfBoundsException();
            }
            return b[0];
        }

        @Override
        public void set(int index, double val) {
            final int biasIndex = w.length();
            if (index < biasIndex) {
                w.set(index, val);
                return;
            }
            if (index > biasIndex) {
                throw new IndexOutOfBoundsException();
            }
            b[0] = val;
        }

        @Override
        public boolean isSparse() {
            return w.isSparse();
        }

        @Override
        public Vec clone() {
            double[] biasCopy = Arrays.copyOf(b, b.length);
            return new VecWithBias(w.clone(), biasCopy);
        }

        @Override
        public void setLength(int length) {
            throw new UnsupportedOperationException("Not supported yet.");
        }
    }
}

