/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.knn;

import java.util.ArrayList;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.distributions.Distribution;
import jsat.distributions.discrete.UniformDiscrete;
import jsat.distributions.empirical.kernelfunc.EpanechnikovKF;
import jsat.distributions.empirical.kernelfunc.KernelFunction;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollection;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;

/**
 * Locally Weighted Learner (LWL).
 * <p>
 * Training prepares the distance metric (if it is trainable) and indexes the
 * numerical features of the training points, paired with their class index or
 * regression target. At prediction time:
 * <ul>
 * <li>the {@code k} nearest training points to the query are retrieved;</li>
 * <li>each point is weighted by the {@link KernelFunction}, applied to its
 * distance divided by the largest neighbor distance;</li>
 * <li>a fresh clone of the base {@link Classifier} or {@link Regressor} is
 * trained on that weighted local data set and used to predict the query.</li>
 * </ul>
 * Categorical features are ignored, because only
 * {@code getNumericalValues()} is indexed.
 */
public class LWL
implements Classifier,
Regressor,
Parameterized {
    private static final long serialVersionUID = 6942465758987345997L;
    /** Target variable of the classification training data; {@code null} until classification training. */
    private CategoricalData predicting;
    /** Prototype classifier cloned per query; may be the same object as {@link #regressor}. */
    private Classifier classifier;
    /** Prototype regressor cloned per query; may be the same object as {@link #classifier}. */
    private Regressor regressor;
    /** Number of neighbors used to build each local data set (always at least 2). */
    private int k;
    private DistanceMetric dm;
    private KernelFunction kf;
    /** Indexed training vectors, each paired with its class index or regression target. */
    private VectorCollection<VecPaired<Vec, Double>> vc;

    /**
     * Copy constructor.
     * <p>
     * The base model is cloned so the copy does not share its prototype with
     * the original. A model that is both a classifier and a regressor is cloned
     * only once, so it stays a single shared object within the copy.
     */
    private LWL(LWL toCopy) {
        if (toCopy.predicting != null) {
            this.predicting = toCopy.predicting.clone();
        }
        if (toCopy.classifier != null) {
            // setClassifier also fills in regressor when the clone implements Regressor
            this.setClassifier(toCopy.classifier.clone());
        }
        if (toCopy.regressor != null && (Object) toCopy.regressor != toCopy.classifier) {
            this.setRegressor(toCopy.regressor.clone());
        }
        this.setNeighbors(toCopy.k);
        this.setDistanceMetric(toCopy.dm.clone());
        this.setKernelFunction(toCopy.kf);
        if (toCopy.vc != null) {
            this.vc = toCopy.vc.clone();
        }
    }

    /**
     * Creates a classification LWL that uses the Epanechnikov kernel and a
     * {@link DefaultVectorCollection}.
     *
     * @param classifier the base classifier cloned for every query
     * @param k the number of neighbors (must be at least 2)
     * @param dm the distance metric used for neighbor search
     */
    public LWL(Classifier classifier, int k, DistanceMetric dm) {
        this(classifier, k, dm, (KernelFunction)EpanechnikovKF.getInstance());
    }

    /**
     * Creates a classification LWL that uses a {@link DefaultVectorCollection}.
     *
     * @param classifier the base classifier cloned for every query
     * @param k the number of neighbors (must be at least 2)
     * @param dm the distance metric used for neighbor search
     * @param kf the kernel that turns scaled distances into weights
     */
    public LWL(Classifier classifier, int k, DistanceMetric dm, KernelFunction kf) {
        this(classifier, k, dm, kf, new DefaultVectorCollection<VecPaired<Vec, Double>>());
    }

    /**
     * Creates a classification LWL.
     *
     * @param classifier the base classifier cloned for every query
     * @param k the number of neighbors (must be at least 2)
     * @param dm the distance metric used for neighbor search
     * @param kf the kernel that turns scaled distances into weights
     * @param vcf the collection used to index training points
     * @throws IllegalArgumentException if {@code k < 2}
     */
    public LWL(Classifier classifier, int k, DistanceMetric dm, KernelFunction kf, VectorCollection<VecPaired<Vec, Double>> vcf) {
        this.setClassifier(classifier);
        this.setNeighbors(k);
        this.setDistanceMetric(dm);
        this.setKernelFunction(kf);
        this.vc = vcf;
    }

    /**
     * Creates a regression LWL that uses the Epanechnikov kernel and a
     * {@link DefaultVectorCollection}.
     *
     * @param regressor the base regressor cloned for every query
     * @param k the number of neighbors (must be at least 2)
     * @param dm the distance metric used for neighbor search
     */
    public LWL(Regressor regressor, int k, DistanceMetric dm) {
        this(regressor, k, dm, (KernelFunction)EpanechnikovKF.getInstance());
    }

    /**
     * Creates a regression LWL that uses a {@link DefaultVectorCollection}.
     *
     * @param regressor the base regressor cloned for every query
     * @param k the number of neighbors (must be at least 2)
     * @param dm the distance metric used for neighbor search
     * @param kf the kernel that turns scaled distances into weights
     */
    public LWL(Regressor regressor, int k, DistanceMetric dm, KernelFunction kf) {
        this(regressor, k, dm, kf, new DefaultVectorCollection<VecPaired<Vec, Double>>());
    }

    /**
     * Creates a regression LWL.
     *
     * @param regressor the base regressor cloned for every query
     * @param k the number of neighbors (must be at least 2)
     * @param dm the distance metric used for neighbor search
     * @param kf the kernel that turns scaled distances into weights
     * @param vcf the collection used to index training points
     * @throws IllegalArgumentException if {@code k < 2}
     */
    public LWL(Regressor regressor, int k, DistanceMetric dm, KernelFunction kf, VectorCollection<VecPaired<Vec, Double>> vcf) {
        this.setRegressor(regressor);
        this.setNeighbors(k);
        this.setDistanceMetric(dm);
        this.setKernelFunction(kf);
        this.vc = vcf;
    }

    /**
     * Classifies a point by training a clone of the base classifier on the
     * kernel-weighted {@code k} nearest neighbors of the point.
     *
     * @param data the point to classify
     * @return the locally trained classifier's result for {@code data}
     * @throws UntrainedModelException if there is no base classifier, if the
     *         model has not been trained on classification data, or if the
     *         neighbor search returns no points
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.classifier == null || this.vc == null || this.predicting == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        List<VecPaired<VecPaired<Vec, Double>, Double>> knn = this.vc.search(data.getNumericalValues(), this.k);
        if (knn.isEmpty()) {
            throw new UntrainedModelException("Model has not been trained");
        }
        ClassificationDataSet localSet = new ClassificationDataSet(knn.get(0).length(), new CategoricalData[0], this.predicting);
        double maxD = maxNeighborDistance(knn);
        for (VecPaired<VecPaired<Vec, Double>, Double> v : knn) {
            // Outer pair: distance to the query. Inner pair: class index stored by getVecList.
            localSet.addDataPoint(v, v.getVector().getPair().intValue(), this.neighborWeight(v.getPair(), maxD));
        }
        Classifier localClassifier = this.classifier.clone();
        localClassifier.train(localSet);
        return localClassifier.classify(data);
    }

    /**
     * Trains the distance metric if it needs training, then indexes the
     * training points paired with their class index.
     *
     * @param dataSet the classification training data
     * @param parallel whether metric training and index building may run in parallel
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        List<VecPaired<Vec, Double>> trainList = this.getVecList(dataSet);
        TrainableDistanceMetric.trainIfNeeded(this.dm, (DataSet)dataSet, parallel);
        this.vc.build(parallel, trainList, this.dm);
        this.predicting = dataSet.getPredicting();
    }

    /**
     * Returns {@code false}. The indexed training points carry only their
     * numerical values and targets, so data set weights are not used.
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Predicts a value by training a clone of the base regressor on the
     * kernel-weighted {@code k} nearest neighbors of the point.
     *
     * @param data the point to predict
     * @return the locally trained regressor's prediction for {@code data}
     * @throws UntrainedModelException if there is no base regressor, or if the
     *         neighbor search returns no points
     */
    @Override
    public double regress(DataPoint data) {
        if (this.regressor == null || this.vc == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        List<VecPaired<VecPaired<Vec, Double>, Double>> knn = this.vc.search(data.getNumericalValues(), this.k);
        if (knn.isEmpty()) {
            throw new UntrainedModelException("Model has not been trained");
        }
        RegressionDataSet localSet = new RegressionDataSet(knn.get(0).length(), new CategoricalData[0]);
        double maxD = maxNeighborDistance(knn);
        for (int i = 0; i < knn.size(); ++i) {
            VecPaired<VecPaired<Vec, Double>, Double> v = knn.get(i);
            // Outer pair: distance to the query. Inner pair: regression target stored by getVecList.
            localSet.addDataPoint(v, (double)v.getVector().getPair());
            localSet.setWeight(i, this.neighborWeight(v.getPair(), maxD));
        }
        Regressor localRegressor = this.regressor.clone();
        localRegressor.train(localSet);
        return localRegressor.regress(data);
    }

    /**
     * Trains the distance metric if it needs training, then indexes the
     * training points paired with their target value.
     *
     * @param dataSet the regression training data
     * @param parallel whether metric training and index building may run in parallel
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        List<VecPaired<Vec, Double>> trainList = this.getVecList(dataSet);
        TrainableDistanceMetric.trainIfNeeded(this.dm, (DataSet)dataSet, parallel);
        this.vc.build(parallel, trainList, this.dm);
    }

    @Override
    public LWL clone() {
        return new LWL(this);
    }

    /**
     * Returns the largest neighbor distance.
     * <p>
     * This does not rely on the search returning results sorted by distance.
     * The result is 0 for an empty list or when every neighbor is at distance 0.
     */
    private static double maxNeighborDistance(List<VecPaired<VecPaired<Vec, Double>, Double>> knn) {
        double max = 0.0;
        for (VecPaired<VecPaired<Vec, Double>, Double> v : knn) {
            max = Math.max(max, v.getPair());
        }
        return max;
    }

    /**
     * Returns the kernel weight for a neighbor at distance {@code dist}, with
     * the distance scaled by {@code maxDist}.
     * <p>
     * When {@code maxDist} is 0, every neighbor coincides with the query and
     * {@code dist / maxDist} would be 0/0 = NaN. Those neighbors get the
     * kernel's value at 0 instead.
     */
    private double neighborWeight(double dist, double maxDist) {
        if (maxDist > 0) {
            return this.kf.k(dist / maxDist);
        }
        return this.kf.k(0.0);
    }

    /** Pairs each training point's numerical values with its class index, stored as a double. */
    private List<VecPaired<Vec, Double>> getVecList(ClassificationDataSet dataSet) {
        ArrayList<VecPaired<Vec, Double>> trainList = new ArrayList<VecPaired<Vec, Double>>(dataSet.size());
        for (int i = 0; i < dataSet.size(); ++i) {
            trainList.add(new VecPaired<Vec, Double>(dataSet.getDataPoint(i).getNumericalValues(), Double.valueOf(dataSet.getDataPointCategory(i))));
        }
        return trainList;
    }

    /** Pairs each training point's numerical values with its regression target. */
    private List<VecPaired<Vec, Double>> getVecList(RegressionDataSet dataSet) {
        ArrayList<VecPaired<Vec, Double>> trainList = new ArrayList<VecPaired<Vec, Double>>(dataSet.size());
        for (int i = 0; i < dataSet.size(); ++i) {
            trainList.add(new VecPaired<Vec, Double>(dataSet.getDataPoint(i).getNumericalValues(), dataSet.getTargetValue(i)));
        }
        return trainList;
    }

    /** Sets the base classifier, and also the base regressor when the object is a {@link Regressor}. */
    private void setClassifier(Classifier classifier) {
        this.classifier = classifier;
        if (classifier instanceof Regressor) {
            this.regressor = (Regressor) classifier;
        }
    }

    /** Sets the base regressor, and also the base classifier when the object is a {@link Classifier}. */
    private void setRegressor(Regressor regressor) {
        this.regressor = regressor;
        if (regressor instanceof Classifier) {
            this.classifier = (Classifier) regressor;
        }
    }

    /**
     * Sets the number of neighbors used to build each local data set.
     *
     * @param k the number of neighbors
     * @throws IllegalArgumentException if {@code k < 2}
     */
    public void setNeighbors(int k) {
        if (k <= 1) {
            throw new IllegalArgumentException("An average requires at least 2 neighbors to be taken into account");
        }
        this.k = k;
    }

    /** @return the number of neighbors used to build each local data set */
    public int getNeighbors() {
        return this.k;
    }

    /** @param dm the distance metric used for neighbor search */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /** @return the distance metric used for neighbor search */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /** @param kf the kernel that turns scaled neighbor distances into weights */
    public void setKernelFunction(KernelFunction kf) {
        this.kf = kf;
    }

    /** @return the kernel that turns scaled neighbor distances into weights */
    public KernelFunction getKernelFunction() {
        return this.kf;
    }

    /**
     * Returns a uniform discrete distribution over candidate values of
     * {@code k} for the given data set.
     * <p>
     * The nominal range is [25, min(200, size/5)]. For small data sets,
     * size/5 drops to 25 or below, which would make the lower bound meet or
     * exceed the upper bound. In that case the lower bound is shrunk so it
     * stays strictly below the upper bound. The range never goes below [2, 3],
     * because {@link #setNeighbors(int)} rejects values below 2.
     *
     * @param d the data set the learner will be trained on
     * @return a distribution over candidate neighbor counts
     */
    public static Distribution guessNeighbors(DataSet d) {
        int max = Math.min(200, d.size() / 5);
        int min = Math.min(25, max - 1);
        if (min < 2) {
            min = 2;
            max = 3;
        }
        return new UniformDiscrete(min, max);
    }
}

