/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.evaluation.intra;

import java.util.List;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.evaluation.intra.IntraClusterEvaluation;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;

/**
 * Intra-cluster evaluation based on the sum of squared pairwise distances
 * between the members of a cluster.
 * <p>
 * For a cluster with {@code N} members {@code x_1 ... x_N} and distance metric
 * {@code d}, the score is {@code (1 / N) * sum_{i < j} d(x_i, x_j)^2}. This
 * equals summing every ordered pair and dividing by {@code 2N}.
 * <p>
 * When the metric is {@link EuclideanDistance}, the score is mathematically
 * equal to the sum of squared distances from each member to the cluster mean.
 * In that case it is computed that way, in time linear in the cluster size.
 * <p>
 * An empty cluster scores {@code 0.0}.
 */
public class SumOfSqrdPairwiseDistances implements IntraClusterEvaluation {
    private DistanceMetric dm;

    /**
     * Creates an evaluator that uses {@link EuclideanDistance}.
     */
    public SumOfSqrdPairwiseDistances() {
        this(new EuclideanDistance());
    }

    /**
     * Creates an evaluator that uses the given distance metric.
     *
     * @param dm the distance metric used to compare cluster members
     */
    public SumOfSqrdPairwiseDistances(DistanceMetric dm) {
        this.dm = dm;
    }

    /**
     * Copy constructor; the distance metric of {@code toCopy} is cloned.
     *
     * @param toCopy the evaluator to copy
     */
    public SumOfSqrdPairwiseDistances(SumOfSqrdPairwiseDistances toCopy) {
        this(toCopy.dm.clone());
    }

    /**
     * Sets the distance metric used to compare cluster members.
     *
     * @param dm the new distance metric
     */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /**
     * Returns the distance metric used to compare cluster members.
     *
     * @return the current distance metric
     */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * Scores a single cluster of a clustering.
     *
     * @param designations cluster assignment of each point in {@code dataSet};
     *                     must have at least {@code dataSet.size()} entries
     * @param dataSet      the clustered data
     * @param clusterID    the cluster to score
     * @return the sum of squared pairwise distances of the cluster divided by
     *         {@code 2N} (see class documentation), or {@code 0.0} if the
     *         cluster has no members
     */
    @Override
    public double evaluate(int[] designations, DataSet dataSet, int clusterID) {
        int[] members = clusterMembers(designations, dataSet.size(), clusterID);
        // An empty cluster has no pairs. Return early so that no mean is
        // computed and nothing is divided by zero.
        if (members.length == 0) {
            return 0.0;
        }
        List<Vec> X = dataSet.getDataVectors();
        List<Double> cache = this.dm.getAccelerationCache(X);
        if (this.dm instanceof EuclideanDistance) {
            return sumSqrdDistToMean(members, X, cache);
        }
        return sumSqrdPairwise(members, X, cache);
    }

    /**
     * Returns, in increasing order, the indices {@code i < n} with
     * {@code designations[i] == clusterID}.
     */
    private static int[] clusterMembers(int[] designations, int n, int clusterID) {
        int count = 0;
        for (int i = 0; i < n; ++i) {
            if (designations[i] == clusterID) {
                ++count;
            }
        }
        int[] members = new int[count];
        int k = 0;
        for (int i = 0; i < n; ++i) {
            if (designations[i] == clusterID) {
                members[k++] = i;
            }
        }
        return members;
    }

    /**
     * Euclidean special case: sum of squared distances from each member to
     * the cluster mean. {@code members} must be non-empty.
     */
    private double sumSqrdDistToMean(int[] members, List<Vec> X, List<Double> cache) {
        Vec mean = new DenseVector(X.get(members[0]).length());
        for (int i : members) {
            mean.mutableAdd(X.get(i));
        }
        mean.mutableDivide((double) members.length);
        List<Double> qi = this.dm.getQueryInfo(mean);
        double sum = 0.0;
        for (int i : members) {
            double d = this.dm.dist(i, mean, qi, X, cache);
            sum += d * d;
        }
        return sum;
    }

    /**
     * General case: squared distances summed over each unordered pair once,
     * then divided by N. This equals counting both orders and dividing by 2N.
     * {@code members} must be non-empty.
     */
    private double sumSqrdPairwise(int[] members, List<Vec> X, List<Double> cache) {
        double sum = 0.0;
        for (int a = 0; a < members.length; ++a) {
            for (int b = a + 1; b < members.length; ++b) {
                double d = this.dm.dist(members[a], members[b], X, cache);
                sum += d * d;
            }
        }
        return sum / members.length;
    }

    /**
     * Scores the given points as if they were all one cluster.
     *
     * @param dataPoints the members of the cluster
     * @return the cluster score, or {@code 0.0} for an empty list
     */
    @Override
    public double evaluate(List<DataPoint> dataPoints) {
        if (dataPoints.isEmpty()) {
            return 0.0;
        }
        return this.evaluate(new int[dataPoints.size()], new SimpleDataSet(dataPoints), 0);
    }

    @Override
    public SumOfSqrdPairwiseDistances clone() {
        return new SumOfSqrdPairwiseDistances(this);
    }
}

