/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear;

import java.util.Iterator;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.Function1D;

/**
 * Static helpers for weighted, element-wise reductions over three {@link Vec}
 * inputs of equal length. Each method picks a code path based on which of its
 * inputs report themselves as sparse via {@link Vec#isSparse()}.
 */
public class VecOps {
    /**
     * Sentinel standing in for "no entry available" when an iterator is
     * exhausted. Its index of -1 is below every index compared against it in
     * this class, so it never matches a real position.
     */
    private static final IndexValue badIV = new IndexValue(-1, Double.NaN);

    /**
     * Accumulates {@code w_i * f(x_i - y_i)} over the entries of {@code w}.
     *
     * <p>When {@code w} is dense and neither {@code x} nor {@code y} is
     * sparse, every index in {@code [0, w.length())} is visited. In all other
     * cases only the entries yielded by iterating {@code w} are visited
     * (presumably its non-zero entries in increasing index order; confirm
     * against the {@code Vec} iterator contract).
     *
     * @param w the weight vector
     * @param x the first value vector
     * @param y the second value vector, subtracted element-wise from {@code x}
     * @param f the function applied to each difference {@code x_i - y_i}
     * @return the accumulated sum
     * @throws ArithmeticException if the three vectors do not all have the
     *         same length
     */
    public static double accumulateSum(Vec w, Vec x, Vec y, Function1D f) {
        checkSameLength(w, x, y);

        double val = 0.0;
        // If f maps 0 to 0, positions where both x and y are 0 add nothing.
        final boolean skipZeros = f.f(0.0) == 0.0;
        final boolean wSparse = w.isSparse();
        final boolean xSparse = x.isSparse();
        final boolean ySparse = y.isSparse();

        if (!xSparse && !ySparse) {
            if (wSparse) {
                // Only w is sparse: walk its entries and index into x and y.
                for (IndexValue wiv : w) {
                    final int idx = wiv.getIndex();
                    val += wiv.getValue() * f.f(x.get(idx) - y.get(idx));
                }
            } else {
                // Everything is dense: plain indexed loop.
                final int n = w.length();
                for (int i = 0; i < n; ++i) {
                    val += w.get(i) * f.f(x.get(i) - y.get(i));
                }
            }
            return val;
        }

        // At least one of x, y is sparse: advance all three iterators in
        // lock-step. This relies on each iterator yielding entries in
        // increasing index order.
        final Iterator<IndexValue> xIter = x.iterator();
        final Iterator<IndexValue> yIter = y.iterator();
        IndexValue xiv = nextOrBad(xIter);
        IndexValue yiv = nextOrBad(yIter);
        for (IndexValue wiv : w) {
            final int index = wiv.getIndex();
            final double w_i = wiv.getValue();
            // Catch x and y up to w's current index, if they can get there.
            while (xiv.getIndex() < index && xIter.hasNext()) {
                xiv = xIter.next();
            }
            while (yiv.getIndex() < index && yIter.hasNext()) {
                yiv = yIter.next();
            }
            // An iterator that has no entry at this index contributes 0.
            final double x_i = xiv.getIndex() == index ? xiv.getValue() : 0.0;
            final double y_i = yiv.getIndex() == index ? yiv.getValue() : 0.0;
            if (skipZeros && x_i == 0.0 && y_i == 0.0) {
                continue;
            }
            val += w_i * f.f(x_i - y_i);
        }
        return val;
    }

    /**
     * Computes the weighted dot product, the sum over {@code i} of
     * {@code w_i * x_i * y_i}.
     *
     * <p>When {@code x} and/or {@code y} is sparse, only positions yielded by
     * the sparse vector's iterator are visited; when both are sparse, only
     * positions yielded by both iterators contribute.
     *
     * @param w the weight vector
     * @param x the first value vector
     * @param y the second value vector
     * @return the weighted dot product
     * @throws ArithmeticException if the three vectors do not all have the
     *         same length
     */
    public static double weightedDot(Vec w, Vec x, Vec y) {
        checkSameLength(w, x, y);

        final boolean xSparse = x.isSparse();
        final boolean ySparse = y.isSparse();
        double sum = 0.0;

        if (xSparse && ySparse) {
            // Merge-style intersection of the two entry streams; relies on
            // both iterators yielding entries in increasing index order.
            final Iterator<IndexValue> xIter = x.iterator();
            final Iterator<IndexValue> yIter = y.iterator();
            IndexValue xiv = nextOrBad(xIter);
            IndexValue yiv = nextOrBad(yIter);
            while (xiv != badIV && yiv != badIV) {
                final int xi = xiv.getIndex();
                final int yi = yiv.getIndex();
                if (xi < yi) {
                    xiv = nextOrBad(xIter);
                } else if (xi > yi) {
                    yiv = nextOrBad(yIter);
                } else {
                    // Same index on both sides: accumulate, then advance both.
                    sum += w.get(xi) * xiv.getValue() * yiv.getValue();
                    xiv = nextOrBad(xIter);
                    yiv = nextOrBad(yIter);
                }
            }
        } else if (xSparse) {
            for (IndexValue iv : x) {
                final int indx = iv.getIndex();
                sum += w.get(indx) * iv.getValue() * y.get(indx);
            }
        } else if (ySparse) {
            // Swap the roles so the sparse vector drives the iteration.
            return VecOps.weightedDot(w, y, x);
        } else {
            final int n = w.length();
            for (int i = 0; i < n; ++i) {
                sum += w.get(i) * x.get(i) * y.get(i);
            }
        }
        return sum;
    }

    /**
     * Verifies that all three vectors have the same length.
     *
     * @throws ArithmeticException if any two lengths differ
     */
    private static void checkSameLength(Vec w, Vec x, Vec y) {
        if (w.length() != x.length() || x.length() != y.length()) {
            throw new ArithmeticException("All 3 vector inputs must have equal lengths");
        }
    }

    /**
     * Returns the iterator's next entry, or the {@link #badIV} sentinel when
     * the iterator has nothing left.
     */
    private static IndexValue nextOrBad(Iterator<IndexValue> iter) {
        return iter.hasNext() ? iter.next() : badIV;
    }
}

