/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.clustering.Clusterer;
import jsat.distributions.empirical.kernelfunc.GaussKF;
import jsat.distributions.empirical.kernelfunc.KernelFunction;
import jsat.distributions.multivariate.MetricKDE;
import jsat.distributions.multivariate.MultivariateKDE;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Mean-shift clustering.
 * <p>
 * Each data point is repeatedly moved to the weighted mean of the neighbours that the
 * {@link MultivariateKDE} reports for its current position. The weight of a neighbour is
 * {@code -k'(d)}, where {@code k} is the KDE's kernel function and {@code d} is the value
 * {@code getNearbyRaw} pairs with that neighbour (presumably its bandwidth-scaled distance).
 * A point stops moving once a shift moves it less than 1e-5, or after
 * {@link #getMaxIterations()} passes. Points whose final positions lie within 1e-3 of each
 * other receive the same cluster id. Points whose only neighbour is themselves are labelled
 * {@code -1}.
 * <p>
 * Calling {@link #cluster(DataSet, boolean, int[])} refits the wrapped KDE, so one instance
 * should not be used for concurrent calls.
 */
public class MeanShift
implements Clusterer {
    private static final long serialVersionUID = 4061491342362690455L;
    /** Default cap on the number of mean-shift passes over the data. */
    public static final int DefaultMaxIterations = 1000;
    /** Default multiplier handed to {@link MultivariateKDE#scaleBandwidth(double)}. */
    public static final double DefaultScaleBandwidthFactor = 1.0;
    /** A point is considered converged once a single shift moves it less than this. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-5;
    /** Converged points closer than this to a cluster's first member join that cluster. */
    private static final double MODE_MERGE_TOLERANCE = 0.001;
    private MultivariateKDE mkde;
    private int maxIterations = DefaultMaxIterations;
    private double scaleBandwidthFactor = DefaultScaleBandwidthFactor;

    /** Creates a mean-shift clusterer using Euclidean distance and a Gaussian kernel. */
    public MeanShift() {
        this(new EuclideanDistance());
    }

    /**
     * Creates a mean-shift clusterer using a Gaussian-kernel {@link MetricKDE}.
     *
     * @param dm the distance metric the KDE uses
     */
    public MeanShift(DistanceMetric dm) {
        this(new MetricKDE(GaussKF.getInstance(), dm));
    }

    /**
     * Creates a mean-shift clusterer backed by the given density estimator.
     *
     * @param mkde the KDE used for neighbour searches and the kernel function; it is refit
     *             on every call to {@link #cluster(DataSet, boolean, int[])}
     */
    public MeanShift(MultivariateKDE mkde) {
        this.mkde = mkde;
    }

    /**
     * Copy constructor. The KDE is cloned, so the copy does not share it.
     *
     * @param toCopy the instance to copy
     */
    public MeanShift(MeanShift toCopy) {
        this.mkde = toCopy.mkde.clone();
        this.maxIterations = toCopy.maxIterations;
        this.scaleBandwidthFactor = toCopy.scaleBandwidthFactor;
    }

    /**
     * Sets the maximum number of mean-shift passes over the data.
     *
     * @param maxIterations a positive iteration cap
     * @throws ArithmeticException if {@code maxIterations <= 0}
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations <= 0) {
            throw new ArithmeticException("Invalid iteration count, " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /** @return the maximum number of mean-shift passes over the data */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the factor passed to {@link MultivariateKDE#scaleBandwidth(double)} after the KDE
     * has been fit to the data.
     *
     * @param scaleBandwidthFactor a finite, strictly positive factor
     * @throws ArithmeticException if the factor is NaN, infinite, zero or negative
     */
    public void setScaleBandwidthFactor(double scaleBandwidthFactor) {
        // !(x > 0) also rejects NaN. Assuming scaleBandwidth multiplies the bandwidth, a
        // non-positive factor could never produce a usable (positive) bandwidth.
        if (!(scaleBandwidthFactor > 0.0) || Double.isInfinite(scaleBandwidthFactor)) {
            throw new ArithmeticException("Invalid scale factor, " + scaleBandwidthFactor);
        }
        this.scaleBandwidthFactor = scaleBandwidthFactor;
    }

    /** @return the factor passed to {@link MultivariateKDE#scaleBandwidth(double)} */
    public double getScaleBandwidthFactor() {
        return this.scaleBandwidthFactor;
    }

    /**
     * Clusters the data set with mean shift.
     *
     * @param dataSet      the data to cluster
     * @param parallel     whether the KDE fit and the shifting passes may run in parallel
     * @param designations array to reuse for the result. A new array is allocated if it is
     *                     {@code null} or shorter than the data set.
     * @return for each data point, its cluster id, or {@code -1} if the point had no
     *         neighbour other than itself
     * @throws FailedToFitException if the work is interrupted or a barrier breaks
     */
    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        final int n = dataSet.size();
        if (designations == null || designations.length < n) {
            designations = new int[n];
        }
        if (n == 0) {
            return designations;
        }
        // A reused array may hold -1 from an earlier run. That value means "noise" below, so
        // clear it. Every entry is later overwritten with -1 or a cluster id.
        Arrays.fill(designations, 0, n, 0);
        try {
            KernelFunction k = this.mkde.getKernelFunction();
            this.mkde.setUsingData(dataSet, parallel);
            this.mkde.scaleBandwidth(this.scaleBandwidthFactor);
            Vec[] xit = new Vec[n];
            for (int i = 0; i < n; ++i) {
                // Copy, because the shifting steps mutate these vectors in place.
                xit[i] = dataSet.getDataPoint(i).getNumericalValues().clone();
            }
            boolean[] converged = new boolean[n];
            this.mainLoop(converged, xit, designations, k, parallel);
            this.assignmentStep(converged, xit, designations);
            return designations;
        }
        catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            Logger.getLogger(MeanShift.class.getName()).log(Level.SEVERE, null, ex);
            throw new FailedToFitException(ex);
        }
        catch (BrokenBarrierException ex) {
            Logger.getLogger(MeanShift.class.getName()).log(Level.SEVERE, null, ex);
            throw new FailedToFitException(ex);
        }
    }

    /**
     * Groups the shifted points into clusters.
     * <p>
     * On entry, {@code converged[i] == true} means "not yet assigned" (mainLoop sets every
     * entry to true). The first unassigned, non-noise point seeds a new cluster. Every later
     * unassigned, non-noise point within {@link #MODE_MERGE_TOLERANCE} of the seed joins
     * that cluster. Points labelled {@code -1} are never used as seeds, because doing so
     * would re-select the same seed forever.
     */
    private void assignmentStep(boolean[] converged, Vec[] xit, int[] designations) {
        int curClusterID = 0;
        for (int basePos = 0; basePos < converged.length; ++basePos) {
            if (!converged[basePos] || designations[basePos] == -1) {
                continue; // already assigned, or noise
            }
            Vec mode = xit[basePos];
            // Starts at basePos so the seed itself (distance 0) is assigned to the cluster.
            for (int i = basePos; i < converged.length; ++i) {
                if (converged[i] && designations[i] != -1
                        && mode.pNormDist(2.0, xit[i]) < MODE_MERGE_TOLERANCE) {
                    converged[i] = false;
                    designations[i] = curClusterID;
                }
            }
            ++curClusterID;
        }
    }

    /**
     * Repeats mean-shift passes until no point moves or the iteration cap is reached.
     * Afterwards, {@code converged} is refilled with {@code true} so that
     * {@link #assignmentStep} can reuse it as an "unassigned" flag.
     * <p>
     * The checked exceptions are still declared so that the handlers in {@code cluster} stay
     * valid. NOTE(review): confirm whether {@code ParallelUtils.run} can actually throw them.
     */
    private void mainLoop(boolean[] converged, Vec[] xit, int[] designations, KernelFunction k, boolean parallel) throws InterruptedException, BrokenBarrierException {
        AtomicBoolean progress = new AtomicBoolean(true);
        int iterations = 0;
        // One scratch vector per worker thread, created lazily on first use.
        ThreadLocal<Vec> localScratch = ThreadLocal.withInitial(() -> new DenseVector(xit[0].length()));
        while (progress.get() && iterations++ < this.maxIterations) {
            progress.set(false);
            ParallelUtils.run(parallel, converged.length, i -> {
                if (converged[i]) {
                    return;
                }
                progress.set(true);
                this.convergenceStep(xit, i, converged, designations, localScratch.get(), k);
            });
        }
        Arrays.fill(converged, true);
    }

    /**
     * Performs one mean-shift update of point {@code i}, in place in {@code xit[i]}.
     * <ul>
     *   <li>A point whose only neighbour is itself is marked converged and labelled
     *       {@code -1}.</li>
     *   <li>If the total weight is zero, the point is marked converged and left where it
     *       is. This covers an empty neighbour list, or weights that all vanish.</li>
     *   <li>Otherwise the point moves to the weighted mean. It is marked converged when that
     *       move is shorter than {@link #CONVERGENCE_TOLERANCE}.</li>
     * </ul>
     */
    private void convergenceStep(Vec[] xit, int i, boolean[] converged, int[] designations, Vec scratch, KernelFunction k) {
        Vec xCur = xit[i];
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> contrib = this.mkde.getNearbyRaw(xCur);
        if (contrib.size() == 1) {
            converged[i] = true;
            designations[i] = -1;
            return;
        }
        double denom = 0.0;
        scratch.zeroOut();
        for (VecPaired<VecPaired<Vec, Integer>, Double> neighbor : contrib) {
            double g = -k.kPrime(neighbor.getPair());
            denom += g;
            scratch.mutableAdd(g, neighbor);
        }
        if (denom == 0.0) {
            // Dividing would turn the point into NaN, which can never be merged into any
            // cluster. This can presumably happen when every neighbour is an exact duplicate
            // and the kernel derivative is 0 at distance 0.
            converged[i] = true;
            return;
        }
        scratch.mutableDivide(denom);
        if (scratch.pNormDist(2.0, xCur) < CONVERGENCE_TOLERANCE) {
            converged[i] = true;
        }
        scratch.copyTo(xCur);
    }

    @Override
    public MeanShift clone() {
        return new MeanShift(this);
    }
}

