/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.imbalance;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.imbalance.SMOTE;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.vectorcollection.DefaultVectorCollection;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

public class BorderlineSMOTE
extends SMOTE {
    private boolean majorityInterpolation;

    public BorderlineSMOTE(Classifier baseClassifier) {
        this(baseClassifier, false);
    }

    public BorderlineSMOTE(Classifier baseClassifier, boolean majorityInterpolation) {
        this(baseClassifier, (DistanceMetric)new EuclideanDistance(), majorityInterpolation);
    }

    public BorderlineSMOTE(Classifier baseClassifier, DistanceMetric dm, boolean majorityInterpolation) {
        this(baseClassifier, dm, 1.0, majorityInterpolation);
    }

    public BorderlineSMOTE(Classifier baseClassifier, DistanceMetric dm, double targetRatio, boolean majorityInterpolation) {
        this(baseClassifier, dm, 5, targetRatio, majorityInterpolation);
    }

    public BorderlineSMOTE(Classifier baseClassifier, DistanceMetric dm, int smoteNeighbors, double targetRatio, boolean majorityInterpolation) {
        super(baseClassifier, dm, smoteNeighbors, targetRatio);
        this.setMajorityInterpolation(majorityInterpolation);
    }

    public BorderlineSMOTE(BorderlineSMOTE toCopy) {
        super(toCopy);
        this.majorityInterpolation = toCopy.majorityInterpolation;
    }

    public void setMajorityInterpolation(boolean majorityInterpolation) {
        this.majorityInterpolation = majorityInterpolation;
    }

    public boolean isMajorityInterpolation() {
        return this.majorityInterpolation;
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        int i;
        if (dataSet.getNumCategoricalVars() != 0) {
            throw new FailedToFitException("SMOTE only works with numeric-only feature values");
        }
        List<Vec> vAll = dataSet.getDataVectors();
        IntList[] classIndex = new IntList[dataSet.getClassSize()];
        for (i = 0; i < classIndex.length; ++i) {
            classIndex[i] = new IntList();
        }
        for (i = 0; i < dataSet.size(); ++i) {
            classIndex[dataSet.getDataPointCategory(i)].add(i);
        }
        double[] priors = dataSet.getPriors();
        DenseVector ratios = DenseVector.toDenseVec(priors).clone();
        int majorityNum = (int)((double)dataSet.size() * ((Vec)ratios).max());
        ((Vec)ratios).mutableDivide(((Vec)ratios).max());
        ArrayList synthetics = new ArrayList();
        DefaultVectorCollection<Vec> VC_all = new DefaultVectorCollection<Vec>(this.dm, vAll, parallel);
        Iterator iterator = ListUtils.range(0, dataSet.getClassSize()).iterator();
        while (iterator.hasNext()) {
            void var19_25;
            int classID = (Integer)iterator.next();
            int samplesNeeded = (int)((double)majorityNum * this.targetRatio - (double)classIndex[classID].size());
            if (samplesNeeded <= 0) continue;
            ArrayList<Vec> V_id = new ArrayList<Vec>();
            Iterator iterator2 = classIndex[classID].iterator();
            while (iterator2.hasNext()) {
                int i2 = (Integer)iterator2.next();
                V_id.add(vAll.get(i2));
            }
            DefaultVectorCollection VC_id = new DefaultVectorCollection(this.dm, V_id, parallel);
            ArrayList<List<Integer>> allNeighbors = new ArrayList<List<Integer>>();
            ArrayList<List<Double>> allDistances = new ArrayList<List<Double>>();
            VC_all.search(V_id, this.smoteNeighbors + 1, allNeighbors, allDistances, parallel);
            ArrayList otherClassSamples = new ArrayList();
            if (this.majorityInterpolation) {
                for (List list : allNeighbors) {
                    otherClassSamples.add(new ArrayList(this.smoteNeighbors));
                }
            }
            IntList danger_id = new IntList();
            boolean bl = false;
            while (var19_25 < VC_id.size()) {
                int same_class = 0;
                List neighors_of_i = (List)allNeighbors.get((int)var19_25);
                for (int j = 1; j < this.smoteNeighbors + 1; ++j) {
                    if (classID == dataSet.getDataPointCategory((Integer)neighors_of_i.get(j))) {
                        ++same_class;
                        continue;
                    }
                    if (!this.majorityInterpolation) continue;
                    ((List)otherClassSamples.get((int)var19_25)).add(VC_all.get((Integer)neighors_of_i.get(j)));
                }
                double sOm = 1.0 - (double)same_class / (double)this.smoteNeighbors;
                if (0.5 <= sOm && sOm < 1.0) {
                    danger_id.add((int)var19_25);
                }
                ++var19_25;
            }
            ArrayList<List<Integer>> arrayList = new ArrayList<List<Integer>>();
            ArrayList<List<Double>> idDistances = new ArrayList<List<Double>>();
            VC_id.search(VC_id, this.smoteNeighbors + 1, arrayList, idDistances, parallel);
            ParallelUtils.run(parallel, samplesNeeded, (start, end) -> {
                Random rand = RandomUtil.getRandom();
                ArrayList<DataPoint> local_new = new ArrayList<DataPoint>();
                for (int i = start; i < end; ++i) {
                    Vec vec_nn;
                    boolean useOtherClass;
                    int sampleIndex = danger_id.isEmpty() ? i % V_id.size() : danger_id.getI(i % danger_id.size());
                    boolean bl = useOtherClass = rand.nextBoolean() && this.majorityInterpolation && !danger_id.isEmpty();
                    if (useOtherClass) {
                        List candidates = (List)otherClassSamples.get(sampleIndex);
                        vec_nn = (Vec)candidates.get(rand.nextInt(candidates.size()));
                    } else {
                        int nn = rand.nextInt(this.smoteNeighbors) + 1;
                        vec_nn = VC_id.get((Integer)((List)idNeighbors.get(sampleIndex)).get(nn));
                    }
                    double gap = rand.nextDouble();
                    if (useOtherClass) {
                        gap /= 2.0;
                    }
                    Vec newVal = ((Vec)V_id.get(sampleIndex)).clone();
                    newVal.mutableMultiply(gap + 1.0);
                    newVal.mutableAdd(gap, vec_nn);
                    local_new.add(new DataPoint(newVal));
                }
                List list = synthetics;
                synchronized (list) {
                    for (DataPoint v : local_new) {
                        synthetics.add(new DataPointPair<Integer>(v, classID));
                    }
                }
            });
        }
        ClassificationDataSet newDataSet = new ClassificationDataSet(ListUtils.mergedView(synthetics, dataSet.getAsDPPList()), dataSet.getPredicting());
        this.baseClassifier.train(newDataSet, parallel);
    }

    @Override
    public BorderlineSMOTE clone() {
        return new BorderlineSMOTE(this);
    }
}

