/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.RBFKernel;
import jsat.linear.CholeskyDecomposition;
import jsat.linear.DenseMatrix;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Kernel ridge regression: a regularized least-squares regressor fit in the
 * feature space implied by a {@link KernelTrick}.
 * <p>
 * Training solves {@code (K + lambda * I) alpha = y} with a Cholesky
 * decomposition. Here {@code K} is the N x N kernel matrix of the training
 * points and {@code y} is the vector of target values. Prediction for a point
 * {@code x} is {@code sum_i alpha_i * k(x_i, x)} over every stored training
 * point {@code x_i}.
 * <p>
 * The full N x N kernel matrix is materialized during training, so memory use
 * grows quadratically with the training set size.
 */
public class KernelRidgeRegression
implements Regressor,
Parameterized {
    private static final long serialVersionUID = 6275333785663250072L;
    /** Ridge regularization added to the kernel matrix diagonal; always positive and finite. */
    private double lambda;
    /** Kernel used both to build the training matrix and to score new points. */
    @Parameter.ParameterHolder
    private KernelTrick k;
    /** Training inputs retained for prediction; {@code null} until trained. */
    private List<Vec> vecs;
    /** Dual coefficients, one per entry of {@link #vecs}; {@code null} until trained. */
    private double[] alphas;

    /**
     * Creates a model with {@code lambda = 1e-6} and a default {@link RBFKernel}.
     */
    public KernelRidgeRegression() {
        this(1.0E-6, new RBFKernel());
    }

    /**
     * Creates a model with the given regularization and kernel.
     *
     * @param lambda the positive, finite ridge regularization constant
     * @param kernel the kernel to use; must not be {@code null}
     * @throws IllegalArgumentException if {@code lambda} is not positive and finite
     * @throws NullPointerException if {@code kernel} is {@code null}
     */
    public KernelRidgeRegression(double lambda, KernelTrick kernel) {
        this.setLambda(lambda);
        this.setKernel(kernel);
    }

    /**
     * Copy constructor.
     * <p>
     * The kernel is cloned and the coefficient array is deep-copied. The list of
     * training vectors is copied, but the {@link Vec} instances in it are shared
     * with {@code toCopy}.
     *
     * @param toCopy the model to copy
     */
    protected KernelRidgeRegression(KernelRidgeRegression toCopy) {
        this(toCopy.lambda, toCopy.getKernel().clone());
        if (toCopy.alphas != null) {
            this.alphas = Arrays.copyOf(toCopy.alphas, toCopy.alphas.length);
        }
        if (toCopy.vecs != null) {
            this.vecs = new ArrayList<Vec>(toCopy.vecs);
        }
    }

    /**
     * Returns a distribution of candidate {@code lambda} values for parameter search.
     * <p>
     * The distribution is currently a fixed {@code LogUniform(1e-7, 0.01)}; the
     * data set argument is not inspected.
     *
     * @param d the data set the model will be trained on (currently unused)
     * @return a distribution over plausible regularization values
     */
    public static Distribution guessLambda(DataSet d) {
        return new LogUniform(1.0E-7, 0.01);
    }

    /**
     * Sets the ridge regularization constant added to the kernel matrix diagonal.
     *
     * @param lambda a positive, finite value
     * @throws IllegalArgumentException if {@code lambda} is NaN, infinite, or not positive
     */
    public void setLambda(double lambda) {
        if (Double.isNaN(lambda) || Double.isInfinite(lambda) || lambda <= 0.0) {
            throw new IllegalArgumentException("lambda must be a positive constant, not " + lambda);
        }
        this.lambda = lambda;
    }

    /**
     * Returns the ridge regularization constant.
     *
     * @return the current regularization value
     */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets the kernel used for training and prediction.
     *
     * @param k the kernel; must not be {@code null}
     * @throws NullPointerException if {@code k} is {@code null}
     */
    public void setKernel(KernelTrick k) {
        // Fail fast here rather than later inside clone(), train() or regress().
        this.k = Objects.requireNonNull(k, "kernel must not be null");
    }

    /**
     * Returns the kernel used by this model.
     *
     * @return the kernel
     */
    public KernelTrick getKernel() {
        return this.k;
    }

    /**
     * Predicts the target value for a point as the alpha-weighted sum of kernel
     * evaluations against every training point.
     *
     * @param data the point to score
     * @return the predicted value
     * @throws IllegalStateException if the model has not been trained
     */
    @Override
    public double regress(DataPoint data) {
        if (this.alphas == null || this.vecs == null) {
            throw new IllegalStateException("KernelRidgeRegression has not been trained");
        }
        Vec x = data.getNumericalValues();
        double score = 0.0;
        for (int i = 0; i < this.alphas.length; ++i) {
            score += this.alphas[i] * this.k.eval(this.vecs.get(i), x);
        }
        return score;
    }

    /**
     * Fits the model by solving {@code (K + lambda * I) alpha = y}.
     * <p>
     * The trained state (training vectors and coefficients) is replaced only
     * after the solve completes. If training throws, any previously trained
     * state is left intact rather than half-overwritten.
     *
     * @param dataSet the training data
     * @param parallel whether to build the kernel matrix and decompose it in parallel
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        final int N = dataSet.size();
        final Vec Y = dataSet.getTargetValues();
        final List<Vec> trainingVecs = new ArrayList<Vec>(N);
        for (int idx = 0; idx < N; ++idx) {
            trainingVecs.add(dataSet.getDataPoint(idx).getNumericalValues());
        }
        final KernelTrick kernel = this.k;
        final double ridge = this.lambda;
        final DenseMatrix K = new DenseMatrix(N, N);
        // Each task i fills the diagonal entry and the upper-triangle row i, and
        // mirrors that row into column i. Tasks write disjoint cells.
        ParallelUtils.run(parallel, N, i -> {
            Vec x_i = trainingVecs.get(i);
            K.set(i, i, kernel.eval(x_i, x_i) + ridge);
            for (int j = i + 1; j < N; ++j) {
                double K_ij = kernel.eval(x_i, trainingVecs.get(j));
                K.set(i, j, K_ij);
                K.set(j, i, K_ij);
            }
        });
        CholeskyDecomposition cd = parallel ? new CholeskyDecomposition(K, ParallelUtils.CACHED_THREAD_POOL) : new CholeskyDecomposition(K);
        double[] newAlphas = cd.solve(Y).arrayCopy();
        // Publish both pieces of trained state together so vecs and alphas always agree.
        this.vecs = trainingVecs;
        this.alphas = newAlphas;
    }

    /**
     * Indicates that per-point weights are ignored by this model.
     *
     * @return {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns an independent copy of this model. See the copy constructor for
     * which parts are deep-copied.
     *
     * @return a copy of this model
     */
    @Override
    public KernelRidgeRegression clone() {
        return new KernelRidgeRegression(this);
    }
}

