/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear.vectorcollection;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.vectorcollection.BaseCaseDT;
import jsat.linear.vectorcollection.IndexDistPair;
import jsat.linear.vectorcollection.IndexNode;
import jsat.linear.vectorcollection.IndexTuple;
import jsat.linear.vectorcollection.ScoreDT;
import jsat.linear.vectorcollection.ScoreDTLazy;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.utils.BoundedSortedList;
import jsat.utils.DoubleList;
import jsat.utils.IndexTable;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/**
 * A {@link VectorCollection} whose points are organized in a tree of
 * {@link IndexNode}s so that dual-tree algorithms can be used: an entire query
 * collection that is itself a {@code DualTree} is searched against this
 * (reference) collection by walking both trees at once and pruning node pairs
 * with a score function.
 *
 * @param <V> the type of vector stored in the collection
 */
public interface DualTree<V extends Vec>
extends VectorCollection<V> {
    /**
     * Value passed as the {@code origScore} argument of a score function to ask
     * for a freshly computed score rather than a re-check of a previously stored
     * one. The k-NN scorer built by
     * {@link #search(VectorCollection, int, List, List, boolean)} treats any
     * negative value this way.
     */
    public static final double COMP_SCORE = -1.0;

    /**
     * Returns the root node of the tree that holds this collection's points.
     *
     * @return the root node
     */
    public IndexNode getRoot();

    /**
     * Returns a clone of this collection, typed as a {@code DualTree}.
     *
     * @return the clone
     */
    @Override
    public DualTree<V> clone();

    /**
     * Computes, with this collection's distance metric, the distance between a
     * point of this collection and a point of another collection.
     *
     * @param self_index  index of the point in this collection
     * @param other_index index of the point in {@code other}
     * @param other       the collection holding the second point
     * @return the distance between the two points
     */
    default public double dist(int self_index, int other_index, DualTree<V> other) {
        // The second point belongs to 'other' and must be looked up by other_index.
        return this.getDistanceMetric().dist(this.get(self_index), other.get(other_index));
    }

    /**
     * Single-query search, re-declared abstract so that implementations must
     * supply it; see {@link VectorCollection} for the contract.
     */
    @Override
    public void search(Vec query, int numNeighbors, List<Integer> neighbors, List<Double> distances);

    /**
     * Finds up to {@code numNeighbors} nearest points of this collection for
     * every point of {@code VC}. When {@code VC} is also a {@code DualTree} a
     * dual-tree traversal is used; otherwise the call is delegated to the default
     * implementation in {@link VectorCollection}.
     *
     * @param VC           the query collection
     * @param numNeighbors number of neighbors to keep per query point
     * @param neighbors    cleared, then filled with one list per query point of
     *                     indices into this collection, in the order kept by
     *                     {@link BoundedSortedList} (presumably nearest first)
     * @param distances    cleared, then filled with the matching distances
     * @param parallel     if true the traversal runs as fork/join tasks on the
     *                     common pool
     */
    @Override
    default public void search(VectorCollection<V> VC, int numNeighbors, List<List<Integer>> neighbors, List<List<Double>> distances, boolean parallel) {
        if (!(VC instanceof DualTree)) {
            VectorCollection.super.search(VC, numNeighbors, neighbors, distances, parallel);
            return;
        }
        @SuppressWarnings("unchecked")
        DualTree<V> Q = (DualTree<V>) VC;
        final int N_r = this.size();
        final int N_q = Q.size();

        // Per-query-node cache of the pruning bound. Lookups must use equals/hashCode
        // (not identity) in both modes: SelfAsChildNode hands out new wrapper objects
        // on every getChild/getParrent call and defines equals/hashCode for this.
        Map<IndexNode, Double> query_B_cache = parallel
                ? new ConcurrentHashMap<IndexNode, Double>(N_q)
                : new HashMap<IndexNode, Double>(N_q);

        // Candidate neighbors found so far, one bounded list per query point.
        List<BoundedSortedList<IndexDistPair>> allPriorities = new ArrayList<>(N_q);
        for (int q = 0; q < N_q; q++) {
            allPriorities.add(new BoundedSortedList<>(numNeighbors));
        }

        List<Double> this_cache = this.getAccelerationCache();
        List<Double> other_cache = Q.getAccelerationCache();
        DoubleList wholeCache = this_cache == null ? null : new DoubleList(ListUtils.mergedView(this_cache, other_cache));

        // Reference points occupy indices [0, N_r); query point j sits at N_r + j.
        List<Vec> allVecs = new ArrayList<>(N_r + N_q);
        for (int r = 0; r < N_r; r++) {
            allVecs.add(this.get(r));
        }
        for (int q = 0; q < N_q; q++) {
            allVecs.add(Q.get(q));
        }
        DistanceMetric dm = this.getDistanceMetric();

        BaseCaseDT base;
        if (parallel) {
            base = (r_indx, q_indx) -> {
                double d = dm.dist(r_indx, N_r + q_indx, allVecs, wholeCache);
                BoundedSortedList<IndexDistPair> target = allPriorities.get(q_indx);
                // Several tasks may update the same query point's candidates concurrently.
                synchronized (target) {
                    target.add(new IndexDistPair(r_indx, d));
                }
                return d;
            };
        } else {
            base = (r_indx, q_indx) -> {
                double d = dm.dist(r_indx, N_r + q_indx, allVecs, wholeCache);
                allPriorities.get(q_indx).add(new IndexDistPair(r_indx, d));
                return d;
            };
        }

        ScoreDTLazy score = (ref, query, origScore) -> {
            if (origScore < 0.0) {
                // Fresh score: the smallest possible distance between the two nodes.
                return ref.minNodeDistance(query);
            }
            // Deferred re-check just before the pair is visited: candidates may have
            // improved since the pair was queued, so prune against the current bound.
            double bound_final = this.computeKnnBound(query, numNeighbors, allPriorities, query_B_cache);
            if (Double.isFinite(bound_final) && origScore > bound_final) {
                return Double.NaN;
            }
            return origScore;
        };

        this.traverse(Q, base, score, true, parallel);

        neighbors.clear();
        distances.clear();
        for (int q = 0; q < N_q; q++) {
            BoundedSortedList<IndexDistPair> knn = allPriorities.get(q);
            IntList n = new IntList(numNeighbors);
            DoubleList d = new DoubleList(numNeighbors);
            for (int j = 0; j < knn.size(); j++) {
                IndexDistPair ip = knn.get(j);
                n.add(ip.getIndex());
                d.add(ip.getDist());
            }
            neighbors.add(n);
            distances.add(d);
        }
    }

    /**
     * Computes, caches in {@code query_B_cache}, and returns the value used by the
     * k-NN search to prune reference nodes for {@code query}: a (reference, query)
     * pair whose minimum node distance exceeds this value is skipped.
     * <p>
     * The result is the minimum of four candidates:
     * <ol>
     * <li>the maximum of the children's cached values and of the current k-th
     *     candidate distance of each point owned by {@code query} (infinite if any
     *     such point has fewer than {@code numNeighbors} candidates);</li>
     * <li>the smallest k-th candidate distance among the owned points, plus
     *     {@code furthestPointDistance() + furthestDescendantDistance()};</li>
     * <li>each child's cached value plus twice how much further {@code query}'s
     *     descendants reach than that child's;</li>
     * <li>the parent's cached value.</li>
     * </ol>
     * Missing cache entries count as positive infinity.
     *
     * @param query          the query node
     * @param numNeighbors   the number of neighbors being searched for
     * @param allPriorities  the candidate lists, indexed by query point
     * @param query_B_cache  cache of previously computed values, updated in place
     * @return the pruning value for {@code query} (possibly positive infinity)
     */
    default public double computeKnnBound(IndexNode query, int numNeighbors, List<BoundedSortedList<IndexDistPair>> allPriorities, Map<IndexNode, Double> query_B_cache) {
        double lambda_q = query.furthestDescendantDistance();

        double bound_1 = Double.NEGATIVE_INFINITY;
        double bound_3 = Double.POSITIVE_INFINITY;
        for (int c = 0; c < query.numChildren(); c++) {
            IndexNode n_c = query.getChild(c);
            double B_nc = query_B_cache.getOrDefault(n_c, Double.POSITIVE_INFINITY);
            bound_1 = Math.max(bound_1, B_nc);
            bound_3 = Math.min(bound_3, B_nc + 2.0 * Math.max(0.0, lambda_q - n_c.furthestDescendantDistance()));
        }

        double bound_2i = Double.POSITIVE_INFINITY;
        for (int p = 0; p < query.numPoints(); p++) {
            BoundedSortedList<IndexDistPair> D_p = allPriorities.get(query.getPoint(p));
            // Base cases may be adding to this list from other threads.
            synchronized (D_p) {
                if (D_p.size() == numNeighbors) {
                    double d = D_p.last().dist;
                    bound_2i = Math.min(bound_2i, d);
                    bound_1 = Math.max(bound_1, d);
                } else {
                    // Fewer than k candidates: this point has no finite k-th distance yet.
                    bound_1 = Double.POSITIVE_INFINITY;
                }
            }
        }
        // No children and no points leaves bound_1 at -inf, which must not be used as a bound.
        if (Double.isInfinite(bound_1)) {
            bound_1 = Double.POSITIVE_INFINITY;
        }
        double bound_2 = bound_2i + query.furthestPointDistance() + lambda_q;

        IndexNode q_parrent = query.getParrent();
        double bound_4 = q_parrent == null ? Double.POSITIVE_INFINITY : query_B_cache.getOrDefault(q_parrent, Double.POSITIVE_INFINITY);

        double bound_final = Math.min(Math.min(bound_1, bound_2), Math.min(bound_3, bound_4));
        query_B_cache.put(query, bound_final);
        return bound_final;
    }

    /**
     * Finds, for every point of {@code VC}, all points of this collection whose
     * distance {@code d} satisfies {@code r_min <= d <= r_max}. When {@code VC} is
     * also a {@code DualTree} a dual-tree traversal is used; otherwise the call is
     * delegated to the default implementation in {@link VectorCollection}.
     *
     * @param VC        the query collection
     * @param r_min     minimum accepted distance (inclusive)
     * @param r_max     maximum accepted distance (inclusive)
     * @param neighbors cleared, then filled with one list per query point of
     *                  indices into this collection, reordered by distance using
     *                  {@link IndexTable}
     * @param distances cleared, then filled with the matching distances
     * @param parallel  if true the traversal runs as fork/join tasks on the
     *                  common pool
     */
    @Override
    default public void search(VectorCollection<V> VC, double r_min, double r_max, List<List<Integer>> neighbors, List<List<Double>> distances, boolean parallel) {
        if (!(VC instanceof DualTree)) {
            VectorCollection.super.search(VC, r_min, r_max, neighbors, distances, parallel);
            return;
        }
        @SuppressWarnings("unchecked")
        DualTree<V> Q = (DualTree<V>) VC;
        final int N_r = this.size();
        final int N_q = Q.size();

        neighbors.clear();
        distances.clear();
        for (int q = 0; q < N_q; q++) {
            neighbors.add(new IntList());
            distances.add(new DoubleList());
        }

        List<Double> this_cache = this.getAccelerationCache();
        List<Double> other_cache = Q.getAccelerationCache();
        List<Double> wholeCache = this_cache == null ? null : ListUtils.mergedView(this_cache, other_cache);

        // Reference points occupy indices [0, N_r); query point j sits at N_r + j.
        List<Vec> allVecs = new ArrayList<>(N_r + N_q);
        for (int r = 0; r < N_r; r++) {
            allVecs.add(this.get(r));
        }
        for (int q = 0; q < N_q; q++) {
            allVecs.add(Q.get(q));
        }
        DistanceMetric dm = this.getDistanceMetric();

        BaseCaseDT base = (r_indx, q_indx) -> {
            double d = dm.dist(r_indx, N_r + q_indx, allVecs, wholeCache);
            if (r_min <= d && d <= r_max) {
                List<Integer> n = neighbors.get(q_indx);
                synchronized (n) {
                    n.add(r_indx);
                    distances.get(q_indx).add(d);
                }
            }
            return d;
        };

        ScoreDT score = (ref, query) -> {
            double[] minMax = ref.minMaxDistance(query);
            double d_min = minMax[0];
            double d_max = minMax[1];
            if (d_min > r_max || d_max < r_min) {
                // No pair of descendants can fall inside the range.
                return Double.NaN;
            }
            if (r_min < d_min && d_max < r_max) {
                // Every pair of descendants is inside the range: record them all
                // directly and prune the pair from further traversal.
                IntList r_dec = new IntList();
                Iterator<Integer> r_iter = ref.DescendantIterator();
                while (r_iter.hasNext()) {
                    r_dec.add(r_iter.next());
                }
                IntList q_dec = new IntList();
                Iterator<Integer> q_iter = query.DescendantIterator();
                while (q_iter.hasNext()) {
                    q_dec.add(q_iter.next());
                }
                for (int r : r_dec) {
                    for (int q : q_dec) {
                        double d = dm.dist(r, N_r + q, allVecs, wholeCache);
                        List<Integer> n = neighbors.get(q);
                        synchronized (n) {
                            n.add(r);
                            distances.get(q).add(d);
                        }
                    }
                }
                return Double.NaN;
            }
            return d_min;
        };

        this.traverse(Q, base, score, false, parallel);

        for (int q = 0; q < neighbors.size(); q++) {
            IndexTable it = new IndexTable(distances.get(q));
            it.apply(distances.get(q));
            it.apply(neighbors.get(q));
        }
    }

    /**
     * Runs a dual-tree traversal of this (reference) tree against {@code Q}.
     * <p>
     * If either tree keeps points in internal nodes, both roots are wrapped in
     * {@link SelfAsChildNode} so that every point lives in a leaf. Both are
     * wrapped together because the wrapper's node-to-node distance methods expect
     * the other node to be a wrapper too.
     *
     * @param Q                the query tree
     * @param base             base case run on every visited (reference point, query point) pair
     * @param score            scores node pairs; {@code NaN} prunes a pair
     * @param improvedTraverse whether to use the improved traversal
     * @param parallel         if true run as fork/join tasks on the common pool
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    default public void traverse(DualTree<V> Q, BaseCaseDT base, ScoreDT score, boolean improvedTraverse, boolean parallel) {
        IndexNode R_root = this.getRoot();
        IndexNode Q_root = Q.getRoot();
        // Checking only one side would leave internal query points unpaired with
        // reference leaves when just the query tree needs wrapping.
        if (!R_root.allPointsInLeaves() || !Q_root.allPointsInLeaves()) {
            R_root = new SelfAsChildNode(R_root);
            Q_root = new SelfAsChildNode(Q_root);
        }
        if (parallel) {
            ForkJoinPool.commonPool().invoke(new DualTreeTraversalAction(R_root, Q_root, base, score, improvedTraverse));
        } else {
            dual_depth_first(R_root, Q_root, base, score, improvedTraverse);
        }
    }

    /**
     * Serial depth-first dual-tree traversal starting at the pair
     * ({@code n_r}, {@code n_q}).
     * <p>
     * The base case is run on every pair of points owned directly by the two
     * nodes. Child pairs are then scored with {@link #COMP_SCORE}. Pairs scoring
     * {@code NaN} are pruned, and the rest are visited recursively in the order of
     * {@link IndexTuple}'s natural ordering (presumably ascending score). With a
     * {@link ScoreDTLazy} scorer each queued pair is re-scored with its stored
     * priority just before it is visited, and is skipped if that yields {@code NaN}.
     *
     * @param n_r            the reference node
     * @param n_q            the query node
     * @param base           the base case
     * @param score          the score function
     * @param improvedSearch when both nodes have children: for each query child, if
     *                       all reference children score the same, try keeping the
     *                       reference node whole instead of splitting it
     */
    public static void dual_depth_first(IndexNode n_r, IndexNode n_q, BaseCaseDT base, ScoreDT score, boolean improvedSearch) {
        for (int i = 0; i < n_r.numPoints(); i++) {
            for (int j = 0; j < n_q.numPoints(); j++) {
                base.base_case(n_r.getPoint(i), n_q.getPoint(j));
            }
        }

        PriorityQueue<IndexTuple> q = new PriorityQueue<>();
        if (n_q.hasChildren() && n_r.hasChildren()) {
            if (!improvedSearch) {
                for (int i = 0; i < n_r.numChildren(); i++) {
                    IndexNode n_r_i = n_r.getChild(i);
                    for (int j = 0; j < n_q.numChildren(); j++) {
                        IndexNode n_q_j = n_q.getChild(j);
                        double s = score.score(n_r_i, n_q_j, COMP_SCORE);
                        if (!Double.isNaN(s)) {
                            q.offer(new IndexTuple(n_r_i, n_q_j, s));
                        }
                    }
                }
            } else {
                for (int c = 0; c < n_q.numChildren(); c++) {
                    IndexNode n_q_c = n_q.getChild(c);
                    List<IndexTuple> q_qc = new ArrayList<>(n_r.numChildren());
                    boolean all_scores_same = true;
                    for (int i = 0; i < n_r.numChildren(); i++) {
                        IndexNode n_r_i = n_r.getChild(i);
                        double s = score.score(n_r_i, n_q_c, COMP_SCORE);
                        if (Double.isNaN(s)) {
                            continue; // pruned, as in every other branch
                        }
                        // Any score differing from its predecessor means they are not all the same.
                        if (!q_qc.isEmpty() && Math.abs(q_qc.get(q_qc.size() - 1).priority - s) >= 1.0E-13) {
                            all_scores_same = false;
                        }
                        q_qc.add(new IndexTuple(n_r_i, n_q_c, s));
                    }
                    if (q_qc.isEmpty()) {
                        continue;
                    }
                    if (all_scores_same && q_qc.get(0).priority > 0.0) {
                        // Splitting the reference node gains nothing here; keep it whole if that scores better.
                        double s = score.score(n_r, n_q_c, COMP_SCORE);
                        if (s > q_qc.get(0).priority) {
                            q.offer(new IndexTuple(n_r, n_q_c, s));
                            continue;
                        }
                    }
                    q.addAll(q_qc);
                }
            }
        } else if (n_q.hasChildren()) {
            for (int j = 0; j < n_q.numChildren(); j++) {
                IndexNode n_q_j = n_q.getChild(j);
                double s = score.score(n_r, n_q_j, COMP_SCORE);
                if (!Double.isNaN(s)) {
                    q.offer(new IndexTuple(n_r, n_q_j, s));
                }
            }
        } else if (n_r.hasChildren()) {
            for (int i = 0; i < n_r.numChildren(); i++) {
                IndexNode n_r_i = n_r.getChild(i);
                double s = score.score(n_r_i, n_q, COMP_SCORE);
                if (!Double.isNaN(s)) {
                    q.offer(new IndexTuple(n_r_i, n_q, s));
                }
            }
        }

        while (!q.isEmpty()) {
            IndexTuple toProcess = q.poll();
            if (score instanceof ScoreDTLazy && Double.isNaN(score.score(toProcess.a, toProcess.b, toProcess.priority))) {
                continue;
            }
            dual_depth_first(toProcess.a, toProcess.b, base, score, improvedSearch);
        }
    }

    /**
     * Fork/join task performing the same traversal as
     * {@link DualTree#dual_depth_first}, except that every surviving child pair
     * becomes a subtask started with {@code invokeAll}. A lazy score re-check
     * therefore happens at the start of each task rather than when a pair is
     * dequeued.
     */
    public static class DualTreeTraversalAction
    extends RecursiveAction
    implements Comparable<DualTreeTraversalAction> {
        IndexNode n_r;
        IndexNode n_q;
        BaseCaseDT base;
        ScoreDT score;
        boolean improvedSearch;
        /** Score of this pair; 0.0 for the root task. */
        double priority;

        /**
         * Creates the root task with priority 0.0.
         */
        public DualTreeTraversalAction(IndexNode n_r, IndexNode n_q, BaseCaseDT base, ScoreDT score, boolean improvedSearch) {
            this(n_r, n_q, base, score, improvedSearch, 0.0);
        }

        /**
         * Creates a task for the pair ({@code n_r}, {@code n_q}) whose score was
         * {@code priority}.
         */
        public DualTreeTraversalAction(IndexNode n_r, IndexNode n_q, BaseCaseDT base, ScoreDT score, boolean improvedSearch, double priority) {
            this.n_r = n_r;
            this.n_q = n_q;
            this.base = base;
            this.score = score;
            this.improvedSearch = improvedSearch;
            this.priority = priority;
        }

        @Override
        protected void compute() {
            if (score instanceof ScoreDTLazy && Double.isNaN(score.score(n_r, n_q, priority))) {
                return;
            }
            for (int i = 0; i < n_r.numPoints(); i++) {
                for (int j = 0; j < n_q.numPoints(); j++) {
                    base.base_case(n_r.getPoint(i), n_q.getPoint(j));
                }
            }

            PriorityQueue<DualTreeTraversalAction> q = new PriorityQueue<>();
            if (n_q.hasChildren() && n_r.hasChildren()) {
                if (!improvedSearch) {
                    for (int i = 0; i < n_r.numChildren(); i++) {
                        IndexNode n_r_i = n_r.getChild(i);
                        for (int j = 0; j < n_q.numChildren(); j++) {
                            IndexNode n_q_j = n_q.getChild(j);
                            double s = score.score(n_r_i, n_q_j, COMP_SCORE);
                            if (!Double.isNaN(s)) {
                                q.offer(new DualTreeTraversalAction(n_r_i, n_q_j, base, score, improvedSearch, s));
                            }
                        }
                    }
                } else {
                    for (int c = 0; c < n_q.numChildren(); c++) {
                        IndexNode n_q_c = n_q.getChild(c);
                        List<DualTreeTraversalAction> q_qc = new ArrayList<>(n_r.numChildren());
                        boolean all_scores_same = true;
                        for (int i = 0; i < n_r.numChildren(); i++) {
                            IndexNode n_r_i = n_r.getChild(i);
                            double s = score.score(n_r_i, n_q_c, COMP_SCORE);
                            if (Double.isNaN(s)) {
                                continue; // pruned, as in every other branch
                            }
                            // Any score differing from its predecessor means they are not all the same.
                            if (!q_qc.isEmpty() && Math.abs(q_qc.get(q_qc.size() - 1).priority - s) >= 1.0E-13) {
                                all_scores_same = false;
                            }
                            q_qc.add(new DualTreeTraversalAction(n_r_i, n_q_c, base, score, improvedSearch, s));
                        }
                        if (q_qc.isEmpty()) {
                            continue;
                        }
                        if (all_scores_same && q_qc.get(0).priority > 0.0) {
                            // Splitting the reference node gains nothing here; keep it whole if that scores better.
                            double s = score.score(n_r, n_q_c, COMP_SCORE);
                            if (s > q_qc.get(0).priority) {
                                q.offer(new DualTreeTraversalAction(n_r, n_q_c, base, score, improvedSearch, s));
                                continue;
                            }
                        }
                        q.addAll(q_qc);
                    }
                }
            } else if (n_q.hasChildren()) {
                for (int j = 0; j < n_q.numChildren(); j++) {
                    IndexNode n_q_j = n_q.getChild(j);
                    double s = score.score(n_r, n_q_j, COMP_SCORE);
                    if (!Double.isNaN(s)) {
                        q.offer(new DualTreeTraversalAction(n_r, n_q_j, base, score, improvedSearch, s));
                    }
                }
            } else if (n_r.hasChildren()) {
                for (int i = 0; i < n_r.numChildren(); i++) {
                    IndexNode n_r_i = n_r.getChild(i);
                    double s = score.score(n_r_i, n_q, COMP_SCORE);
                    if (!Double.isNaN(s)) {
                        q.offer(new DualTreeTraversalAction(n_r_i, n_q, base, score, improvedSearch, s));
                    }
                }
            }
            invokeAll(q);
        }

        /** Orders tasks by their {@code priority} (score). */
        @Override
        public int compareTo(DualTreeTraversalAction o) {
            return Double.compare(this.priority, o.priority);
        }
    }

    /**
     * Wraps an {@link IndexNode} so that every point lives in a leaf. A wrapper of
     * an internal node owns no points itself and has one extra child (at index
     * {@code wrapping.numChildren()}) that presents the wrapped node's own points
     * as a leaf. Equality and hashing are based on the wrapped node and the
     * {@code asLeaf} flag, because a fresh wrapper is created on every
     * {@link #getChild}/{@link #getParrent} call.
     *
     * @param <N> the wrapped node type
     */
    public static class SelfAsChildNode<N extends IndexNode<N>>
    implements IndexNode<SelfAsChildNode<N>> {
        /** If true, this wrapper exposes only the wrapped node's own points and has no children. */
        public boolean asLeaf;
        /** The node being wrapped. */
        N wrapping;

        /**
         * Wraps {@code wrapping}; the wrapper is a leaf exactly when the wrapped
         * node has no children.
         */
        public SelfAsChildNode(N wrapping) {
            this.wrapping = wrapping;
            this.asLeaf = !wrapping.hasChildren();
        }

        /**
         * Wraps {@code wrapping}, presenting it as a leaf if {@code asLeaf} is set.
         */
        public SelfAsChildNode(boolean asLeaf, N wrapping) {
            this.asLeaf = asLeaf;
            this.wrapping = wrapping;
        }

        /** A non-leaf wrapper owns no points, so its furthest point distance is 0. */
        @Override
        public double furthestPointDistance() {
            return asLeaf ? wrapping.furthestPointDistance() : 0.0;
        }

        /** A leaf wrapper only covers the wrapped node's own points. */
        @Override
        public double furthestDescendantDistance() {
            return asLeaf ? wrapping.furthestPointDistance() : wrapping.furthestDescendantDistance();
        }

        /** A non-leaf wrapper has the wrapped node's children plus one self-as-leaf child. */
        @Override
        public int numChildren() {
            return asLeaf ? 0 : wrapping.numChildren() + 1;
        }

        @SuppressWarnings({"rawtypes", "unchecked"})
        @Override
        public IndexNode getChild(int indx) {
            if (indx == wrapping.numChildren()) {
                // The extra, last child: this node's own points presented as a leaf.
                return new SelfAsChildNode<>(true, wrapping);
            }
            return new SelfAsChildNode(wrapping.getChild(indx));
        }

        @Override
        public Vec getVec(int indx) {
            return wrapping.getVec(indx);
        }

        @Override
        public int numPoints() {
            return asLeaf ? wrapping.numPoints() : 0;
        }

        /**
         * Returns the index of the {@code indx}-th point of a leaf wrapper.
         *
         * @throws IndexOutOfBoundsException if this wrapper is not a leaf, since
         *         non-leaf wrappers own no points
         */
        @Override
        public int getPoint(int indx) {
            if (!asLeaf) {
                throw new IndexOutOfBoundsException("Non-leaf wrapper node does not own any points");
            }
            return wrapping.getPoint(indx);
        }

        /**
         * The self-as-leaf child of an internal node has that node's non-leaf
         * wrapper as parent; otherwise the parent is the non-leaf wrapper of the
         * wrapped node's parent, or {@code null} at the root.
         */
        @Override
        public SelfAsChildNode<N> getParrent() {
            if (asLeaf && wrapping.hasChildren()) {
                return new SelfAsChildNode<>(false, wrapping);
            }
            N parrent = wrapping.getParrent();
            if (parrent == null) {
                return null;
            }
            return new SelfAsChildNode<>(false, parrent);
        }

        @Override
        public double minNodeDistance(SelfAsChildNode<N> other) {
            return wrapping.minNodeDistance(other.wrapping);
        }

        @Override
        public double maxNodeDistance(SelfAsChildNode<N> other) {
            return wrapping.maxNodeDistance(other.wrapping);
        }

        @Override
        public double minNodeDistance(int other) {
            return wrapping.minNodeDistance(other);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof SelfAsChildNode)) {
                return false;
            }
            SelfAsChildNode<?> other = (SelfAsChildNode<?>) obj;
            return this.asLeaf == other.asLeaf && this.wrapping.equals(other.wrapping);
        }

        @Override
        public int hashCode() {
            int hash = 5;
            hash = 71 * hash + (asLeaf ? 1 : 0);
            hash = 71 * hash + wrapping.hashCode();
            return hash;
        }

        @Override
        public double[] minMaxDistance(SelfAsChildNode<N> other) {
            return wrapping.minMaxDistance(other.wrapping);
        }
    }
}

