/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.distributions.kernels.KernelTrick;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.SubMatrix;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.UpdateableRegressor;
import jsat.utils.DoubleList;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/**
 * Kernel Recursive Least Squares (KRLS) regressor.
 * <p>
 * The model keeps a dictionary of input vectors {@code v_i} with coefficients
 * {@code alpha_i}; the prediction for an input {@code x} is the
 * coefficient-weighted sum of kernel values {@code k(v_i, x)}. Each call to
 * {@link #update(DataPoint, double, double)} tests whether the new input is
 * approximately linearly dependent (in feature space) on the dictionary: with
 * {@code k_x} the kernel values against the dictionary, {@code a = K^-1 k_x}
 * and {@code delta = k(x, x) - a^T k_x}, the point is added to the dictionary
 * when {@code delta} exceeds the {@linkplain #getErrorTolerance() error
 * tolerance}; otherwise only the coefficients (and the auxiliary matrix
 * {@code P}) are refined. The update equations follow the ALD-based KRLS of
 * Engel, Mannor &amp; Meir (2004).
 * <p>
 * {@code K}, {@code K^-1} and {@code P} are leading square {@link SubMatrix}
 * views of larger backing matrices whose size is doubled when full, so the
 * dictionary can grow without reallocating on every addition. Call
 * {@link #finalizeModel()} once training is done to release that state; the
 * model can still predict afterwards but can no longer be updated.
 * <p>
 * Sample weights are not supported and are ignored by {@link #update}.
 */
public class KernelRLS implements UpdateableRegressor, Parameterized {
    private static final long serialVersionUID = -7292074388953854317L;
    /** Rows/columns (and coefficient slots) allocated by {@link #setUp}. */
    private static final int INITIAL_CAPACITY = 100;
    @Parameter.ParameterHolder
    private KernelTrick k;
    /** Threshold on the ALD residual {@code delta} above which a point joins the dictionary. */
    private double errorTolerance;
    /** Dictionary of retained input vectors; null until {@link #setUp} is called. */
    private List<Vec> vecs;
    /** Kernel acceleration cache kept parallel to {@link #vecs}; null if the kernel does not support acceleration. */
    private List<Double> kernelAccel;
    /** Kernel matrix over the dictionary (view into {@link #KExpanded}); null before the first update. */
    private Matrix K;
    /** Inverse of {@link #K} (view into {@link #InvKExpanded}). */
    private Matrix InvK;
    /** The KRLS matrix {@code P} (view into {@link #PExpanded}). */
    private Matrix P;
    /** Backing storage for {@link #K}; null after {@link #finalizeModel()}. */
    private Matrix KExpanded;
    /** Backing storage for {@link #InvK}; null after {@link #finalizeModel()}. */
    private Matrix InvKExpanded;
    /** Backing storage for {@link #P}; null after {@link #finalizeModel()}. */
    private Matrix PExpanded;
    /** Dictionary coefficients; only the first {@code vecs.size()} entries are meaningful. */
    private double[] alphaExpanded;

    /**
     * Creates a new KernelRLS regressor.
     *
     * @param k              the kernel to use
     * @param errorTolerance positive threshold on the ALD residual above which
     *                       a point is added to the dictionary
     * @throws IllegalArgumentException if {@code errorTolerance} is NaN,
     *                                  infinite, or not positive
     */
    public KernelRLS(KernelTrick k, double errorTolerance) {
        this.k = k;
        this.setErrorTolerance(errorTolerance);
    }

    /**
     * Deep-copy constructor.
     *
     * @param toCopy the model to copy
     */
    protected KernelRLS(KernelRLS toCopy) {
        this.k = toCopy.k.clone();
        this.errorTolerance = toCopy.errorTolerance;
        if (toCopy.vecs != null) {
            this.vecs = new ArrayList<Vec>(toCopy.vecs.size());
            for (Vec vec : toCopy.vecs) {
                this.vecs.add(vec.clone());
            }
        }
        if (toCopy.kernelAccel != null) {
            // The cache is parallel to vecs and is handed to the kernel by
            // regress() and update(); leaving it null would drop it in the copy.
            this.kernelAccel = new DoubleList();
            this.kernelAccel.addAll(toCopy.kernelAccel);
        }
        int n = this.vecs == null ? 0 : this.vecs.size();
        // Only rebuild the views the source actually had: a null K means "no
        // point seen yet", which update() relies on to initialise the model.
        if (toCopy.KExpanded != null) {
            this.KExpanded = toCopy.KExpanded.clone();
            if (toCopy.K != null) {
                this.K = new SubMatrix(this.KExpanded, 0, 0, n, n);
            }
        }
        if (toCopy.InvKExpanded != null) {
            this.InvKExpanded = toCopy.InvKExpanded.clone();
            if (toCopy.InvK != null) {
                this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, n, n);
            }
        }
        if (toCopy.PExpanded != null) {
            this.PExpanded = toCopy.PExpanded.clone();
            if (toCopy.P != null) {
                this.P = new SubMatrix(this.PExpanded, 0, 0, n, n);
            }
        }
        if (toCopy.alphaExpanded != null) {
            this.alphaExpanded = Arrays.copyOf(toCopy.alphaExpanded, toCopy.alphaExpanded.length);
        }
    }

    /**
     * Sets the threshold on the ALD residual above which a new point is added
     * to the dictionary. Larger values give smaller (sparser) models.
     *
     * @param v the new tolerance
     * @throws IllegalArgumentException if {@code v} is NaN, infinite, or not positive
     */
    public void setErrorTolerance(double v) {
        if (Double.isNaN(v) || Double.isInfinite(v) || v <= 0.0) {
            throw new IllegalArgumentException("The error tolerance must be a positive constant, not " + v);
        }
        this.errorTolerance = v;
    }

    /**
     * @return the threshold on the ALD residual used to grow the dictionary
     */
    public double getErrorTolerance() {
        return this.errorTolerance;
    }

    /**
     * @return the number of vectors in the dictionary, or 0 if the model has
     *         not been set up
     */
    public int getModelSize() {
        if (this.vecs == null) {
            return 0;
        }
        return this.vecs.size();
    }

    /**
     * Releases the training-only state (the {@code K}, {@code K^-1} and
     * {@code P} matrices) and trims the coefficient array to the dictionary
     * size. The model can still be used by {@link #regress(DataPoint)}, but
     * {@link #update} will throw afterwards. Does nothing if the model was
     * never set up; calling it more than once is harmless.
     */
    public void finalizeModel() {
        if (this.vecs == null) {
            return;
        }
        this.alphaExpanded = Arrays.copyOf(this.alphaExpanded, this.vecs.size());
        this.PExpanded = null;
        this.P = null;
        this.InvKExpanded = null;
        this.InvK = null;
        this.KExpanded = null;
        this.K = null;
    }

    /**
     * Predicts the target for {@code data} from its numerical features.
     *
     * @param data the point to predict
     * @return the predicted value (0 if the dictionary is still empty)
     * @throws IllegalStateException if the model has not been set up or trained
     */
    @Override
    public double regress(DataPoint data) {
        if (this.vecs == null) {
            throw new IllegalStateException("KernelRLS has not been trained");
        }
        Vec y = data.getNumericalValues();
        return this.k.evalSum(this.vecs, this.kernelAccel, this.alphaExpanded, y, 0, this.vecs.size());
    }

    /**
     * Trains the model; {@code parallel} is ignored because training is a
     * sequential online pass.
     *
     * @param dataSet  the training data
     * @param parallel ignored
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Resets the model and presents every point of {@code dataSet} once, in
     * data-set order, to {@link #update}.
     *
     * @param dataSet the training data
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        this.setUp(dataSet.getCategories(), dataSet.getNumNumericalVars());
        for (int i = 0; i < dataSet.size(); ++i) {
            this.update(dataSet.getDataPoint(i), dataSet.getWeight(i), dataSet.getTargetValue(i));
        }
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public KernelRLS clone() {
        return new KernelRLS(this);
    }

    /**
     * Discards any previous model and allocates empty training state. The
     * attribute descriptions are not used.
     *
     * @param categoricalAttributes ignored
     * @param numericAttributes     ignored
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes) {
        this.vecs = new ArrayList<Vec>();
        this.kernelAccel = this.k.supportsAcceleration() ? new DoubleList() : null;
        this.K = null;
        this.InvK = null;
        this.P = null;
        this.KExpanded = new DenseMatrix(INITIAL_CAPACITY, INITIAL_CAPACITY);
        this.InvKExpanded = new DenseMatrix(INITIAL_CAPACITY, INITIAL_CAPACITY);
        this.PExpanded = new DenseMatrix(INITIAL_CAPACITY, INITIAL_CAPACITY);
        this.alphaExpanded = new double[INITIAL_CAPACITY];
    }

    /**
     * Performs one KRLS step with the labelled point {@code (dataPoint, y_t)}.
     *
     * @param dataPoint the input; only its numerical features are used
     * @param weight    ignored (weighted data is not supported)
     * @param y_t       the target value
     * @throws IllegalStateException if {@link #setUp} has not been called or
     *                               the model has been finalized
     */
    @Override
    public void update(DataPoint dataPoint, double weight, double y_t) {
        if (this.vecs == null) {
            throw new IllegalStateException("setUp must be called before update");
        }
        if (this.KExpanded == null) {
            throw new IllegalStateException("The model has been finalized and can no longer be updated");
        }
        Vec x_t = dataPoint.getNumericalValues();
        List<Double> qi = this.k.getQueryInfo(x_t);
        double k_tt = this.k.eval(0, 0, Arrays.asList(x_t), qi);
        if (this.K == null) {
            addFirstPoint(x_t, qi, k_tt, y_t);
            return;
        }
        int size = this.K.rows();
        // kernel values between x_t and every dictionary vector
        DenseVector kxt = new DenseVector(size);
        for (int i = 0; i < size; ++i) {
            kxt.set(i, this.k.eval(i, x_t, qi, this.vecs, this.kernelAccel));
        }
        // ALD test: alphas_t = K^-1 k_t are the coefficients of the best
        // approximation of phi(x_t) by the dictionary; delta_t is the squared
        // residual of that approximation.
        Vec alphas_t = this.InvK.multiply(kxt);
        double delta_t = k_tt - alphas_t.dot(kxt);
        // prediction error of the current model on (x_t, y_t)
        double err = y_t - kxt.dot(new DenseVector(this.alphaExpanded, 0, size));
        if (delta_t > this.errorTolerance) {
            addToDictionary(x_t, qi, k_tt, kxt, alphas_t, delta_t, err);
        } else {
            refineCoefficients(alphas_t, err);
        }
    }

    /** Initialises K, K^-1, P and the coefficients from the very first point. */
    private void addFirstPoint(Vec x_t, List<Double> qi, double k_tt, double y_t) {
        // NOTE(review): k_tt == 0 (e.g. a zero vector under a linear kernel)
        // yields infinite entries here; confirm whether that needs guarding.
        this.K = new SubMatrix(this.KExpanded, 0, 0, 1, 1);
        this.K.set(0, 0, k_tt);
        this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, 1, 1);
        this.InvK.set(0, 0, 1.0 / k_tt);
        this.P = new SubMatrix(this.PExpanded, 0, 0, 1, 1);
        this.P.set(0, 0, 1.0);
        this.alphaExpanded[0] = y_t / k_tt;
        this.vecs.add(x_t);
        if (this.kernelAccel != null) {
            this.kernelAccel.addAll(qi);
        }
    }

    /** Grows the dictionary by x_t, extending K, K^-1, P and the coefficients by one. */
    private void addToDictionary(Vec x_t, List<Double> qi, double k_tt, Vec kxt, Vec alphas_t,
            double delta_t, double err) {
        int size = kxt.length();
        this.vecs.add(x_t);
        if (this.kernelAccel != null) {
            this.kernelAccel.addAll(qi);
        }
        if (size == this.KExpanded.rows()) {
            // backing storage is full: double it (the views keep referring to
            // the same backing matrix objects)
            this.KExpanded.changeSize(size * 2, size * 2);
            this.InvKExpanded.changeSize(size * 2, size * 2);
            this.PExpanded.changeSize(size * 2, size * 2);
            this.alphaExpanded = Arrays.copyOf(this.alphaExpanded, size * 2);
        }
        // Block-inverse update: the old K^-1 block becomes K^-1 + a a^T / delta.
        // This must happen on the old-size view, before the views are enlarged.
        Matrix.OuterProductUpdate(this.InvK, alphas_t, alphas_t, 1.0 / delta_t);
        this.K = new SubMatrix(this.KExpanded, 0, 0, size + 1, size + 1);
        this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, size + 1, size + 1);
        this.P = new SubMatrix(this.PExpanded, 0, 0, size + 1, size + 1);
        for (int i = 0; i < size; ++i) {
            double kVal = kxt.get(i);
            double invKVal = -alphas_t.get(i) / delta_t;
            this.K.set(size, i, kVal);
            this.K.set(i, size, kVal);
            this.InvK.set(size, i, invKVal);
            this.InvK.set(i, size, invKVal);
        }
        this.K.set(size, size, k_tt);
        this.InvK.set(size, size, 1.0 / delta_t);
        // P grows as [[P, 0], [0, 1]]; the new off-diagonal entries are still
        // zero because the backing matrix is only ever written inside the view
        this.P.set(size, size, 1.0);
        for (int i = 0; i < size; ++i) {
            this.alphaExpanded[i] -= alphas_t.get(i) * err / delta_t;
        }
        this.alphaExpanded[size] = err / delta_t;
    }

    /** Updates P and the coefficients without growing the dictionary. */
    private void refineCoefficients(Vec alphas_t, double err) {
        // q_t = P a / (1 + a^T P a)
        Vec q_t = this.P.multiply(alphas_t);
        q_t.mutableDivide(1.0 + alphas_t.dot(q_t));
        // P <- P - q_t (a^T P), using the pre-update P on the right
        Matrix.OuterProductUpdate(this.P, q_t, alphas_t.multiply(this.P), -1.0);
        Vec invKq = this.InvK.multiply(q_t);
        for (int i = 0; i < invKq.length(); ++i) {
            this.alphaExpanded[i] += invKq.get(i) * err;
        }
    }
}

