/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.DoubleAdder;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.DoubleList;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Centroid-based classifier: training computes one weighted mean vector per
 * class, and classification scores each class by how close the query point
 * is to that class's mean under the configured {@link DistanceMetric}.
 *
 * <p>Only numerical features are supported; {@link #train} rejects data sets
 * that contain categorical variables. The model must be trained before
 * {@link #classify} is called.
 */
public class Rocchio
implements Classifier {
    private static final long serialVersionUID = 889524967453326516L;
    /** One centroid per class, indexed by class id; {@code null} until trained. */
    private List<Vec> rocVecs;
    /** Metric used to compare a query point against the class centroids. */
    private final DistanceMetric dm;
    /**
     * Acceleration cache the metric built for {@link #rocVecs}. May be
     * {@code null}: it is whatever {@code dm.getAccelerationCache} returned.
     */
    private List<Double> rocCache;

    /**
     * Creates a classifier that compares points with {@link EuclideanDistance}.
     */
    public Rocchio() {
        this(new EuclideanDistance());
    }

    /**
     * Creates a classifier that compares points with the given metric.
     *
     * @param dm the distance metric used for training and classification
     */
    public Rocchio(DistanceMetric dm) {
        this.dm = dm;
        this.rocVecs = null;
    }

    /**
     * Scores every class by its centroid's distance to {@code data}.
     *
     * <p>With distances {@code d_i} and {@code sum = Σ d_i}, class {@code i}
     * receives {@code (1 - d_i / sum) / (C - 1)} where {@code C} is the number
     * of classes, so the closest centroid gets the largest score and the
     * scores add up to 1. When there is a single class, or when every distance
     * is zero, the scores are uniform ({@code 1 / C}).
     *
     * @param data the point to classify; only its numerical values are used
     * @return the per-class scores, indexed by class id
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        final int numClasses = this.rocVecs.size();
        CategoricalResults cr = new CategoricalResults(numClasses);
        Vec target = data.getNumericalValues();
        List<Double> qi = this.dm.getQueryInfo(target);

        double[] distances = new double[numClasses];
        double sum = 0.0;
        for (int i = 0; i < numClasses; ++i) {
            distances[i] = this.dm.dist(i, target, qi, this.rocVecs, this.rocCache);
            sum += distances[i];
        }

        // Degenerate cases: a lone class would otherwise score 1 - d/d = 0, and
        // an all-zero distance sum would otherwise produce 0/0 = NaN.
        if (numClasses == 1 || sum == 0.0) {
            for (int i = 0; i < numClasses; ++i) {
                cr.setProb(i, 1.0 / numClasses);
            }
            return cr;
        }

        // Σ (1 - d_i/sum) = C - 1, so dividing by C - 1 makes the scores sum
        // to 1. For two classes the divisor is 1 and the values are unchanged.
        final double normalizer = numClasses - 1;
        for (int i = 0; i < numClasses; ++i) {
            cr.setProb(i, (1.0 - distances[i] / sum) / normalizer);
        }
        return cr;
    }

    /**
     * Fits one centroid per class as the weighted mean of that class's points.
     *
     * @param dataSet  the training data; must contain no categorical variables
     * @param parallel passed through to the metric training, the summation and
     *                 the cache construction
     * @throws FailedToFitException if the data set has categorical variables
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        if (dataSet.getNumCategoricalVars() != 0) {
            throw new FailedToFitException("Classifier requires all variables be numerical");
        }
        final int numClasses = dataSet.getClassSize();
        TrainableDistanceMetric.trainIfNeeded(this.dm, (DataSet)dataSet, parallel);
        final int d = dataSet.getNumNumericalVars();
        final DoubleAdder totalWeight = new DoubleAdder();

        // Each index range accumulates its own per-class weighted sums, and the
        // partial results are then added together element-wise.
        Vec[] classSums = ParallelUtils.run(parallel, dataSet.size(), (start, end) -> {
            Vec[] localSums = new Vec[numClasses];
            for (int c = 0; c < numClasses; ++c) {
                localSums[c] = new DenseVector(d);
            }
            for (int i = start; i < end; ++i) {
                double w = dataSet.getWeight(i);
                localSums[dataSet.getDataPointCategory(i)].mutableAdd(w, dataSet.getDataPoint(i).getNumericalValues());
                totalWeight.add(w);
            }
            return localSums;
        }, (t, u) -> {
            for (int c = 0; c < t.length; ++c) {
                t[c].mutableAdd(u[c]);
            }
            return t;
        });
        this.rocVecs = new ArrayList<>(Arrays.asList(classSums));

        // Divide each class sum by (total weight * class prior) to turn it into
        // a mean.
        // NOTE(review): a class whose prior is 0 divides by zero here — confirm
        // whether callers can pass data sets with empty classes.
        final double weightSum = totalWeight.sum();
        double[] priors = dataSet.getPriors();
        for (int c = 0; c < numClasses; ++c) {
            this.rocVecs.get(c).mutableDivide(weightSum * priors[c]);
        }
        this.rocCache = this.dm.getAccelerationCache(this.rocVecs, parallel);
    }

    /**
     * @return {@code true}; {@link #train} weights every point by
     *         {@code dataSet.getWeight(i)}
     */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Returns an independent copy of this classifier.
     *
     * <p>The metric, the centroids and the acceleration cache are all copied,
     * so training or mutating one instance does not affect the other.
     *
     * @return a deep copy of this classifier
     */
    @Override
    public Rocchio clone() {
        // Clone the metric too: train() may retrain it via trainIfNeeded, which
        // would otherwise silently change the model it was copied from.
        Rocchio copy = new Rocchio(this.dm == null ? null : this.dm.clone());
        if (this.rocVecs != null) {
            copy.rocVecs = new ArrayList<Vec>(this.rocVecs.size());
            for (Vec v : this.rocVecs) {
                copy.rocVecs.add(v.clone());
            }
            // rocCache is whatever the metric returned and may be null; guard
            // against it so copying a null cache cannot fail.
            if (this.rocCache != null) {
                copy.rocCache = new DoubleList(this.rocCache);
            }
        }
        return copy;
    }
}

