/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import jsat.classifiers.DataPoint;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.LUPDecomposition;
import jsat.linear.Matrix;
import jsat.linear.SingularValueDecomposition;
import jsat.linear.Vec;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Kriging regressor that interpolates target values from a fitted
 * {@link Variogram}.
 *
 * <p>Training builds an {@code (N + 1) x (N + 1)} linear system from the
 * variogram values of all pairwise Euclidean distances between the {@code N}
 * training points, bordered by a row and column of ones, and solves it for a
 * weight vector. Prediction is the dot product of that weight vector with the
 * variogram values between the query point and every training point (plus a
 * trailing {@code 1}).
 *
 * <p>The training data set is retained by reference because it is needed at
 * prediction time.
 */
public class OrdinaryKriging implements Regressor, Parameterized {
    private static final long serialVersionUID = -5774553215322383751L;

    /** Variogram used to turn distances into matrix / vector entries. */
    private Variogram vari;
    /** Solution of the kriging system; {@code null} until {@link #train} has run. */
    private Vec X;
    /** Training data, kept because {@link #regress} needs the training points. */
    private RegressionDataSet dataSet;
    /** Square of the measurement error; subtracted from the system's diagonal. */
    private double errorSqrd;
    /** Nugget value handed to {@link Variogram#train}. */
    private double nugget;

    /** Nugget used when none is given to a constructor. */
    public static final double DEFAULT_NUGGET = 0.1;
    /** Measurement error used when none is given to a constructor. */
    public static final double DEFAULT_ERROR = 0.1;

    /**
     * Creates a new, untrained model.
     *
     * @param vari the variogram to fit and evaluate
     * @param error the measurement error; its square is stored
     * @param nugget the nugget passed to the variogram when training
     * @throws ArithmeticException if {@code nugget} is negative, NaN or infinite
     */
    public OrdinaryKriging(Variogram vari, double error, double nugget) {
        this.vari = vari;
        this.setMeasurementError(error);
        // Go through the setter so the constructor enforces the same rules as setNugget.
        this.setNugget(nugget);
    }

    /**
     * Creates a new, untrained model using {@link #DEFAULT_NUGGET}.
     *
     * @param vari the variogram to fit and evaluate
     * @param error the measurement error
     */
    public OrdinaryKriging(Variogram vari, double error) {
        this(vari, error, DEFAULT_NUGGET);
    }

    /**
     * Creates a new, untrained model using {@link #DEFAULT_ERROR} and
     * {@link #DEFAULT_NUGGET}.
     *
     * @param vari the variogram to fit and evaluate
     */
    public OrdinaryKriging(Variogram vari) {
        this(vari, DEFAULT_ERROR);
    }

    /**
     * Creates a new, untrained model using a default {@link PowVariogram} and
     * the default error and nugget.
     */
    public OrdinaryKriging() {
        this(new PowVariogram());
    }

    /**
     * Predicts the target value for a data point.
     *
     * @param data the query point; only its numerical values are used
     * @return the dot product of the trained weights with the variogram values
     *         between {@code data} and each training point
     * @throws IllegalStateException if the model has not been trained
     */
    @Override
    public double regress(DataPoint data) {
        if (this.X == null) {
            throw new IllegalStateException("OrdinaryKriging has not been trained");
        }
        Vec x = data.getNumericalValues();
        // X has one weight per training point plus one trailing entry.
        int npt = this.X.length() - 1;
        double[] distVals = new double[npt + 1];
        for (int i = 0; i < npt; ++i) {
            Vec trainPoint = this.dataSet.getDataPoint(i).getNumericalValues();
            distVals[i] = this.vari.val(x.pNormDist(2.0, trainPoint));
        }
        // Matches the bordering row/column of ones in the training system.
        distVals[npt] = 1.0;
        return this.X.dot(DenseVector.toDenseVec(distVals));
    }

    /**
     * Fits the variogram to the data set and solves the kriging system.
     *
     * <p>The system is solved with an LUP decomposition. If the determinant of
     * that decomposition is NaN or smaller than {@code 1e-5} in magnitude, a
     * singular value decomposition is used to solve the system instead.
     *
     * @param dataSet the training data; retained by reference
     * @param parallel whether to build the system and the LUP decomposition in parallel
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.dataSet = dataSet;
        int N = dataSet.size();
        DenseVector Y = new DenseVector(N + 1);
        DenseMatrix V = new DenseMatrix(N + 1, N + 1);
        this.vari.train(dataSet, this.nugget);
        this.setUpVectorMatrix(N, dataSet, V, Y, parallel);
        for (int i = 0; i < N; ++i) {
            V.increment(i, i, -this.errorSqrd);
        }
        LUPDecomposition lup = parallel
                ? new LUPDecomposition(V, ParallelUtils.CACHED_THREAD_POOL)
                : new LUPDecomposition(V);
        // Inspect the determinant once, and only run the solver whose result is kept.
        double det = lup.det();
        if (Double.isNaN(det) || Math.abs(det) < 1.0E-5) {
            SingularValueDecomposition svd = new SingularValueDecomposition(V);
            this.X = svd.solve(Y);
        } else {
            this.X = lup.solve(Y);
        }
    }

    /**
     * Fills the kriging matrix and right-hand side.
     *
     * <p>Entry {@code (i, j)} of {@code V} for {@code i, j < N} is the variogram
     * value of the Euclidean distance between points {@code i} and {@code j};
     * the last row and column are ones with a zero in the corner. The first
     * {@code N} entries of {@code Y} are the target values.
     *
     * @param N the number of training points
     * @param dataSet the training data
     * @param V the {@code (N + 1) x (N + 1)} matrix to fill
     * @param Y the vector of length {@code N + 1} to fill
     * @param parallel whether to process rows in parallel
     */
    private void setUpVectorMatrix(int N, RegressionDataSet dataSet, Matrix V, Vec Y, boolean parallel) {
        ParallelUtils.run(parallel, N, i -> {
            Vec xi = dataSet.getDataPoint(i).getNumericalValues();
            // Only walk j >= i and mirror the value: each pair is evaluated once,
            // and every matrix cell is written by exactly one task.
            for (int j = i; j < N; ++j) {
                Vec xj = dataSet.getDataPoint(j).getNumericalValues();
                double val = this.vari.val(xi.pNormDist(2.0, xj));
                V.set(i, j, val);
                V.set(j, i, val);
            }
            V.set(i, N, 1.0);
            V.set(N, i, 1.0);
            Y.set(i, dataSet.getTargetValue(i));
        });
        V.set(N, N, 0.0);
    }

    /**
     * @return {@code false}; this implementation does not use data point weights
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Creates a copy of this model. The variogram and the solved weight vector
     * are cloned; the training data set is shared by reference.
     *
     * @return the copy
     */
    @Override
    public OrdinaryKriging clone() {
        OrdinaryKriging copy = new OrdinaryKriging(this.vari.clone());
        // Copy the stored square directly instead of round-tripping through sqrt.
        copy.errorSqrd = this.errorSqrd;
        copy.setNugget(this.getNugget());
        if (this.X != null) {
            copy.X = this.X.clone();
        }
        if (this.dataSet != null) {
            copy.dataSet = this.dataSet;
        }
        return copy;
    }

    /**
     * Sets the measurement error. Only its square is stored, so the sign is
     * not preserved.
     *
     * @param error the measurement error
     */
    public void setMeasurementError(double error) {
        this.errorSqrd = error * error;
    }

    /**
     * @return the square root of the stored squared measurement error
     */
    public double getMeasurementError() {
        return Math.sqrt(this.errorSqrd);
    }

    /**
     * Sets the nugget passed to the variogram during training.
     *
     * @param nugget the nugget; must be finite and not negative (zero is allowed)
     * @throws ArithmeticException if {@code nugget} is negative, NaN or infinite
     */
    public void setNugget(double nugget) {
        if (nugget < 0.0 || Double.isNaN(nugget) || Double.isInfinite(nugget)) {
            throw new ArithmeticException("Nugget must be a non-negative, finite value, not " + nugget);
        }
        this.nugget = nugget;
    }

    /**
     * @return the nugget passed to the variogram during training
     */
    public double getNugget() {
        return this.nugget;
    }

    /**
     * Power-law variogram {@code alpha * r^beta}. {@code beta} is fixed at
     * construction; {@code alpha} is fitted by {@link #train}.
     */
    public static class PowVariogram implements Variogram {
        /** Fitted scale factor; zero until {@link #train} has run. */
        private double alpha;
        /** Exponent applied to the distance. */
        private double beta;

        /**
         * Creates a variogram with an exponent of {@code 1.5}.
         */
        public PowVariogram() {
            this(1.5);
        }

        /**
         * Creates a variogram with the given exponent.
         *
         * @param beta the exponent applied to the distance
         */
        public PowVariogram(double beta) {
            this.beta = beta;
        }

        /**
         * Fits {@code alpha} from all distinct pairs of points as
         * {@code sum(rb * (0.5 * (yi - yj)^2 - nugget^2)) / sum(rb^2)}, where
         * {@code rb} is the pair's Euclidean distance raised to {@code beta}.
         *
         * <p>If the denominator is exactly zero (fewer than two points, or
         * every pair is at distance zero), {@code alpha} is set to zero rather
         * than the NaN that {@code 0 / 0} would produce.
         *
         * @param dataSet the training data
         * @param nugget the nugget; its square is subtracted from each pair's semivariance
         */
        @Override
        public void train(RegressionDataSet dataSet, double nugget) {
            int npt = dataSet.size();
            double num = 0.0;
            double denom = 0.0;
            double nugSqrd = nugget * nugget;
            for (int i = 0; i < npt; ++i) {
                Vec xi = dataSet.getDataPoint(i).getNumericalValues();
                double yi = dataSet.getTargetValue(i);
                for (int j = i + 1; j < npt; ++j) {
                    Vec xj = dataSet.getDataPoint(j).getNumericalValues();
                    double yj = dataSet.getTargetValue(j);
                    double rb = Math.pow(xi.pNormDist(2.0, xj), this.beta);
                    num += rb * (0.5 * Math.pow(yi - yj, 2.0) - nugSqrd);
                    denom += rb * rb;
                }
            }
            // An exact-zero check keeps a NaN denominator (from NaN inputs) visible as NaN.
            this.alpha = denom == 0.0 ? 0.0 : num / denom;
        }

        /**
         * @param r the distance
         * @return {@code alpha * r^beta}
         */
        @Override
        public double val(double r) {
            return this.alpha * Math.pow(r, this.beta);
        }

        /**
         * @return a new variogram with the same {@code alpha} and {@code beta}
         */
        @Override
        public Variogram clone() {
            PowVariogram copy = new PowVariogram(this.beta);
            copy.alpha = this.alpha;
            return copy;
        }
    }

    /**
     * Function of distance that is fitted to a data set and then evaluated to
     * populate the kriging system.
     */
    public static interface Variogram extends Cloneable {
        /**
         * Fits the variogram to the data.
         *
         * @param dataSet the training data
         * @param nugget the nugget value configured on the regressor
         */
        public void train(RegressionDataSet dataSet, double nugget);

        /**
         * Evaluates the variogram.
         *
         * @param r the distance between two points
         * @return the variogram value at {@code r}
         */
        public double val(double r);

        /**
         * @return a copy of this variogram
         */
        public Variogram clone();
    }
}

