/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import jsat.SingleWeightVectorModel;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.QRDecomposition;
import jsat.linear.Vec;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Multiple linear regression fit by (optionally weighted) least squares.
 * <p>
 * The model predicts {@code B . x + a}, where {@code x} is the numerical
 * feature vector of a data point, {@code B} is the learned coefficient vector
 * and {@code a} is the learned intercept. Training solves the least-squares
 * problem through a QR decomposition of the design matrix. Only numerical
 * features are supported; a data set with categorical features is rejected.
 */
public class MultipleLinearRegression
implements Regressor,
SingleWeightVectorModel {
    private static final long serialVersionUID = 7694194181910565061L;
    /** Learned coefficients, one per numerical feature; {@code null} until trained. */
    private Vec B;
    /** Learned intercept (bias) term. */
    private double a;
    /** Whether the data set's per-point weights are used during training. */
    private boolean useWeights = false;

    /**
     * Creates a new regressor that uses data point weights during training.
     */
    public MultipleLinearRegression() {
        this(true);
    }

    /**
     * Creates a new regressor.
     *
     * @param useWeights {@code true} to fit weighted least squares using the
     *        data set's per-point weights, {@code false} to ignore the weights
     */
    public MultipleLinearRegression(boolean useWeights) {
        this.useWeights = useWeights;
    }

    /**
     * Predicts the target value as {@code B . x + a}.
     *
     * @param data the point whose numerical values are used as {@code x}
     * @return the predicted value
     */
    @Override
    public double regress(DataPoint data) {
        return this.B.dot(data.getNumericalValues()) + this.a;
    }

    /**
     * Fits the coefficients and intercept by least squares.
     * <p>
     * A design matrix is built from a leading column of ones (for the intercept)
     * followed by each point's numerical features. The system is then solved
     * with a QR decomposition. When weights are enabled, each row of the design
     * matrix and each target are first scaled by the square root of that
     * point's weight.
     *
     * @param dataSet  the training data; must contain only numerical features
     * @param parallel {@code true} to compute the QR decomposition using a thread pool
     * @throws RuntimeException if the data set has categorical features
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        if (dataSet.getNumCategoricalVars() > 0) {
            throw new RuntimeException("Multiple Linear Regression only works with numerical values");
        }
        final int n = dataSet.size();
        final int d = dataSet.getNumNumericalVars();
        DenseMatrix X = new DenseMatrix(n, d + 1);
        DenseVector Y = new DenseVector(n);
        for (int i = 0; i < n; ++i) {
            DataPointPair<Double> dpp = dataSet.getDataPointPair(i);
            Y.set(i, dpp.getPair());
            // Column 0 is the constant term, so tmp[0] below becomes the intercept
            X.set(i, 0, 1.0);
            Vec vals = dpp.getVector();
            for (int j = 0; j < vals.length(); ++j) {
                X.set(i, j + 1, vals.get(j));
            }
        }
        if (this.useWeights) {
            // Scaling row i and y_i by sqrt(w_i) means ordinary least squares on the
            // scaled system minimizes sum_i w_i * (y_i - x_i . B)^2.
            DenseVector weights = new DenseVector(n);
            for (int i = 0; i < n; ++i) {
                weights.set(i, Math.sqrt(dataSet.getWeight(i)));
            }
            Matrix.diagMult(weights, X);
            Y.mutablePairwiseMultiply(weights);
        }
        Matrix[] QR = parallel ? X.qr(ParallelUtils.CACHED_THREAD_POOL) : X.qr();
        QRDecomposition qrDecomp = new QRDecomposition(QR[0], QR[1]);
        Vec tmp = qrDecomp.solve(Y);
        this.a = tmp.get(0);
        this.B = new DenseVector(d);
        for (int i = 1; i < tmp.length(); ++i) {
            this.B.set(i - 1, tmp.get(i));
        }
    }

    /**
     * @return {@code true} if this instance was configured to use data point weights
     */
    @Override
    public boolean supportsWeightedData() {
        return this.useWeights;
    }

    /**
     * @return the learned coefficient vector (not a copy); {@code null} before training
     */
    @Override
    public Vec getRawWeight() {
        return this.B;
    }

    /**
     * @return the learned intercept
     */
    @Override
    public double getBias() {
        return this.a;
    }

    /**
     * @param index must be 0, because this model has exactly one weight vector
     * @return the learned coefficient vector
     * @throws IndexOutOfBoundsException if {@code index >= 1}
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index < 1) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @param index must be 0, because this model has exactly one weight vector
     * @return the learned intercept
     * @throws IndexOutOfBoundsException if {@code index >= 1}
     */
    @Override
    public double getBias(int index) {
        if (index < 1) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @return always 1
     */
    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Creates a deep copy of this model, including its weighting configuration
     * and any learned coefficients.
     *
     * @return an independent copy of this regressor
     */
    @Override
    public MultipleLinearRegression clone() {
        // Preserve the weighting mode; the no-arg constructor would force useWeights=true
        MultipleLinearRegression copy = new MultipleLinearRegression(this.useWeights);
        if (this.B != null) {
            copy.B = this.B.clone();
        }
        copy.a = this.a;
        return copy;
    }
}

