/*
 * Decompiled with CFR 0.152.
 */
package jsat.math.optimization;

import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.FunctionVec;
import jsat.math.optimization.LineSearch;

/**
 * Line search that looks for a step length satisfying the strong Wolfe
 * conditions, following the bracketing algorithm (Algorithm 3.5) and the
 * cubic-interpolation "zoom" phase (Algorithm 3.6) of Nocedal &amp; Wright,
 * <i>Numerical Optimization</i>.
 * <p>
 * With {@code phi(a) = f(x + a*p)}, an accepted step {@code a} satisfies
 * <pre>
 *   phi(a)    &lt;= phi(0) + c1 * a * phi'(0)    (sufficient decrease)
 *   |phi'(a)| &lt;= -c2 * phi'(0)                (curvature)
 * </pre>
 * with {@code 0 < c1 < c2 < 1}. The curvature test can only pass when
 * {@code phi'(0) < 0}, so the search direction is expected to be a descent
 * direction. The search is bounded by a fixed iteration budget, so the
 * returned step is not guaranteed to satisfy both conditions.
 * <p>
 * An instance remembers the step, function value and directional derivative
 * of its previous call, and uses them to pick the first trial step of the
 * next call. These fields are mutated without synchronization, so one
 * instance should not serve several searches at once; use {@link #clone()}.
 */
public class WolfeNWLineSearch implements LineSearch {
    /** Maximum number of trial steps in the bracketing phase. */
    private static final int MAX_BRACKET_ITERATIONS = 10;
    /** Maximum number of trial steps in the zoom phase. */
    private static final int MAX_ZOOM_ITERATIONS = 10;
    /** Lower bound applied to the initial trial step. */
    private static final double MIN_INITIAL_STEP = 1.0E-13;
    /**
     * Interpolated zoom steps closer than this fraction of the bracket width
     * to either end (or outside it, or NaN) are replaced by bisection.
     */
    private static final double ZOOM_SAFEGUARD = 0.1;

    /*
     * The defaults are the most extreme legal values, so that the first
     * setter called by the constructor passes its c1 < c2 cross-check.
     */
    private double c1 = Math.nextUp(0.0f);
    private double c2 = Math.nextAfter(1.0f, Double.NEGATIVE_INFINITY);
    private AlphaInit initMethod = AlphaInit.METHOD1;
    /** Step returned by the previous search; only read once gradP_prev is set. */
    double alpha_prev = -1.0;
    /** f(x_k) seen by the previous search, or NaN before the first search. */
    double f_x_prev = Double.NaN;
    /** Directional derivative seen by the previous search, or NaN before the first search. */
    double gradP_prev = Double.NaN;

    /** Creates a line search with the common defaults {@code c1 = 1e-4}, {@code c2 = 0.9}. */
    public WolfeNWLineSearch() {
        this(1.0E-4, 0.9);
    }

    /**
     * Creates a line search with the given Wolfe constants.
     *
     * @param c1 sufficient-decrease constant, {@code 0 < c1 < c2}
     * @param c2 curvature constant, {@code c1 < c2 < 1}
     * @throws IllegalArgumentException if the constants violate those bounds
     */
    public WolfeNWLineSearch(double c1, double c2) {
        this.setC1(c1);
        this.setC2(c2);
    }

    /**
     * Sets the sufficient-decrease constant.
     *
     * @param c1 new value; must satisfy {@code 0 < c1 < getC2()}
     * @throws IllegalArgumentException if the bound is violated or {@code c1} is NaN
     */
    public void setC1(double c1) {
        // Negated comparisons so that NaN is rejected as well.
        if (!(c1 > 0.0)) {
            throw new IllegalArgumentException("c1 must be greater than 0, not " + c1);
        }
        if (!(c1 < this.c2)) {
            throw new IllegalArgumentException("c1 must be less than c2");
        }
        this.c1 = c1;
    }

    /** @return the sufficient-decrease constant {@code c1} */
    public double getC1() {
        return this.c1;
    }

    /**
     * Sets the curvature constant.
     *
     * @param c2 new value; must satisfy {@code getC1() < c2 < 1}
     * @throws IllegalArgumentException if the bound is violated or {@code c2} is NaN
     */
    public void setC2(double c2) {
        // Negated comparison so that NaN is rejected as well.
        if (!(c2 < 1.0)) {
            throw new IllegalArgumentException("c2 must be less than 1, not " + c2);
        }
        if (c2 <= this.c1) {
            throw new IllegalArgumentException("c2 must be greater than c1");
        }
        this.c2 = c2;
    }

    /** @return the curvature constant {@code c2} */
    public double getC2() {
        return this.c2;
    }

    /**
     * Selects how the first trial step of each search is chosen.
     *
     * @param initMethod the initialization strategy, not null
     * @throws NullPointerException if {@code initMethod} is null
     */
    public void setInitMethod(AlphaInit initMethod) {
        if (initMethod == null) {
            throw new NullPointerException("initMethod");
        }
        this.initMethod = initMethod;
    }

    /** @return the strategy used to choose the first trial step */
    public AlphaInit getInitMethod() {
        return this.initMethod;
    }

    /**
     * Searches along {@code p_k} from {@code x_k} for a step satisfying the
     * strong Wolfe conditions.
     *
     * @param alpha_max largest step that will be tried or returned
     * @param x_k starting point (not modified by this method)
     * @param x_grad gradient at {@code x_k}; only read when {@code gradP} is NaN
     * @param p_k search direction, expected to be a descent direction
     * @param f objective function
     * @param fp gradient of the objective
     * @param f_x {@code f(x_k)}, or NaN to have it computed
     * @param gradP {@code x_grad . p_k}, or NaN to have it computed
     * @param x_alpha_pk output: on return holds {@code x_k + alpha * p_k} for
     *        the returned {@code alpha}
     * @param fxApRet optional (may be null); element 0 receives the function
     *        value at the returned step
     * @param grad_x_alpha_pk passed to {@code fp} as the gradient destination
     *        at every trial point, so it last receives the gradient at the
     *        returned step (assumes {@code fp} writes into it; TODO confirm)
     * @param parallel forwarded to {@code f} and {@code fp}
     * @return the chosen step length, never larger than
     *         {@code max(alpha_max, 0)} unless {@code alpha_max} is NaN
     */
    @Override
    public double lineSearch(double alpha_max, Vec x_k, Vec x_grad, Vec p_k, Function f, FunctionVec fp, double f_x, double gradP, Vec x_alpha_pk, double[] fxApRet, Vec grad_x_alpha_pk, boolean parallel) {
        if (Double.isNaN(f_x)) {
            f_x = f.f(x_k, parallel);
        }
        if (Double.isNaN(gradP)) {
            gradP = x_grad.dot(p_k);
        }
        final double phi0 = f_x;
        final double phi0P = gradP;

        // Never try a step larger than the caller allows.
        double alpha_cur = Math.min(this.initialStep(f_x, gradP), alpha_max);

        // Step of the most recently evaluated point. x_alpha_pk, fxApRet and
        // grad_x_alpha_pk always describe this point between iterations.
        double alpha_last = 0.0;
        double phi_prev = phi0;
        double phi_prevP = phi0P;
        double result = 0.0;
        boolean done = false;

        x_k.copyTo(x_alpha_pk);
        for (int iter = 1; iter <= MAX_BRACKET_ITERATIONS; ++iter) {
            // Move incrementally from the last trial point to the new one.
            x_alpha_pk.mutableAdd(alpha_cur - alpha_last, p_k);
            double phi_cur = f.f(x_alpha_pk, parallel);
            if (fxApRet != null) {
                fxApRet[0] = phi_cur;
            }
            double phi_curP = fp.f(x_alpha_pk, grad_x_alpha_pk, parallel).dot(p_k);

            // Sufficient decrease violated, or no longer decreasing: a
            // minimizer is bracketed by [alpha_last, alpha_cur].
            if (phi_cur > phi0 + this.c1 * alpha_cur * phi0P || (iter > 1 && phi_cur >= phi_prev)) {
                result = this.zoom(alpha_last, alpha_cur, phi_prev, phi_cur, phi_prevP, phi_curP, phi0, phi0P, x_k, x_alpha_pk, p_k, f, fp, fxApRet, grad_x_alpha_pk, parallel);
                done = true;
                break;
            }
            // Strong curvature condition holds: accept.
            if (Math.abs(phi_curP) <= -this.c2 * phi0P) {
                result = alpha_cur;
                done = true;
                break;
            }
            // Slope turned non-negative: minimizer between alpha_cur and alpha_last.
            if (phi_curP >= 0.0) {
                result = this.zoom(alpha_cur, alpha_last, phi_cur, phi_prev, phi_curP, phi_prevP, phi0, phi0P, x_k, x_alpha_pk, p_k, f, fp, fxApRet, grad_x_alpha_pk, parallel);
                done = true;
                break;
            }

            alpha_last = alpha_cur;
            phi_prev = phi_cur;
            phi_prevP = phi_curP;
            if (alpha_cur >= alpha_max) {
                // Already evaluated at the bound; the output buffers match it.
                result = alpha_cur;
                done = true;
                break;
            }
            alpha_cur = Math.min(2.0 * alpha_cur, alpha_max);
        }
        if (!done) {
            // Budget exhausted: return the last evaluated step, which is the
            // point the output buffers currently describe.
            result = alpha_last;
        }

        this.alpha_prev = result;
        this.f_x_prev = f_x;
        this.gradP_prev = gradP;
        return result;
    }

    /**
     * Chooses the first trial step according to {@link #initMethod}, using
     * the state remembered from the previous search when it is available.
     *
     * @param f_x function value at the current point
     * @param gradP directional derivative at the current point
     * @return a step of at least {@link #MIN_INITIAL_STEP}; 1.0 when no
     *         history is available or the formula yields NaN
     */
    private double initialStep(double f_x, double gradP) {
        double alpha = 1.0;
        if (this.initMethod == AlphaInit.METHOD1 && !Double.isNaN(this.gradP_prev)) {
            alpha = this.alpha_prev * this.gradP_prev / gradP;
        } else if (this.initMethod == AlphaInit.METHOD2 && !Double.isNaN(this.f_x_prev)) {
            alpha = 2.0 * (f_x - this.f_x_prev) / gradP;
            alpha = Math.min(1.0, 1.01 * alpha);
        }
        if (Double.isNaN(alpha)) {
            // e.g. 0/0; Math.max would propagate NaN into the search.
            alpha = 1.0;
        }
        return Math.max(alpha, MIN_INITIAL_STEP);
    }

    /**
     * Zoom phase: shrinks a bracket known to contain acceptable steps until
     * the strong Wolfe conditions hold or the iteration budget runs out.
     * {@code alphaLow} is the end with the lower function value among those
     * satisfying sufficient decrease. Either end may be the larger one.
     * <p>
     * On return, {@code x_alpha_p}, {@code fxApRet} and
     * {@code grad_x_alpha_pk} describe the returned step.
     *
     * @return the accepted step, or the last trial step if the budget ran out
     */
    private double zoom(double alphaLow, double alphaHi, double phi_alphaLow, double phi_alphaHigh, double phi_alphaLowP, double phi_alphaHighP, double phi0, double phi0P, Vec x, Vec x_alpha_p, Vec p, Function f, FunctionVec fp, double[] fxApRet, Vec grad_x_alpha_pk, boolean parallel) {
        double alpha_j = alphaLow;
        for (int iter = 0; iter < MAX_ZOOM_ITERATIONS; ++iter) {
            // Minimizer of the cubic interpolating phi and phi' at both ends
            // (Nocedal & Wright eq. 3.59).
            double d1 = phi_alphaLowP + phi_alphaHighP - 3.0 * (phi_alphaLow - phi_alphaHigh) / (alphaLow - alphaHi);
            double d2 = Math.signum(alphaHi - alphaLow) * Math.sqrt(d1 * d1 - phi_alphaLowP * phi_alphaHighP);
            alpha_j = alphaHi - (alphaHi - alphaLow) * (phi_alphaHighP + d2 - d1) / (phi_alphaHighP - phi_alphaLowP + 2.0 * d2);

            // Safeguard: bisect if the interpolant is NaN (negative
            // discriminant, degenerate bracket), outside the bracket, or too
            // close to one of its ends.
            double lo = Math.min(alphaLow, alphaHi);
            double hi = Math.max(alphaLow, alphaHi);
            double margin = ZOOM_SAFEGUARD * (hi - lo);
            if (!(alpha_j >= lo + margin && alpha_j <= hi - margin)) {
                alpha_j = lo + (hi - lo) / 2.0;
            }

            x.copyTo(x_alpha_p);
            x_alpha_p.mutableAdd(alpha_j, p);
            double phi_j = f.f(x_alpha_p, parallel);
            if (fxApRet != null) {
                fxApRet[0] = phi_j;
            }
            double phi_jP = fp.f(x_alpha_p, grad_x_alpha_pk, parallel).dot(p);

            // Sufficient decrease uses the directional derivative phi'(0).
            if (phi_j > phi0 + this.c1 * alpha_j * phi0P || phi_j >= phi_alphaLow) {
                alphaHi = alpha_j;
                phi_alphaHigh = phi_j;
                phi_alphaHighP = phi_jP;
                continue;
            }
            // Strong curvature condition: -c2*phi'(0) is positive for a descent direction.
            if (Math.abs(phi_jP) <= -this.c2 * phi0P) {
                return alpha_j;
            }
            // Keep the bracket oriented so that it still contains a minimizer.
            if (phi_jP * (alphaHi - alphaLow) >= 0.0) {
                alphaHi = alphaLow;
                phi_alphaHigh = phi_alphaLow;
                phi_alphaHighP = phi_alphaLowP;
            }
            alphaLow = alpha_j;
            phi_alphaLow = phi_j;
            phi_alphaLowP = phi_jP;
        }
        return alpha_j;
    }

    /**
     * @return {@code true}: every trial point's gradient is requested into the
     *         caller's gradient vector, the last one being at the returned step
     */
    @Override
    public boolean updatesGrad() {
        return true;
    }

    /** @return an independent copy, including the remembered previous-search state */
    @Override
    public WolfeNWLineSearch clone() {
        WolfeNWLineSearch clone = new WolfeNWLineSearch(this.c1, this.c2);
        clone.initMethod = this.initMethod;
        clone.alpha_prev = this.alpha_prev;
        clone.f_x_prev = this.f_x_prev;
        clone.gradP_prev = this.gradP_prev;
        return clone;
    }

    /** Strategies for choosing the first trial step of a search. */
    public static enum AlphaInit {
        /**
         * {@code alpha0 = alpha_prev * gradP_prev / gradP}: scales the previous
         * step by the ratio of directional derivatives.
         */
        METHOD1,
        /**
         * {@code alpha0 = min(1, 1.01 * 2 * (f_x - f_x_prev) / gradP)}: based on
         * the previous decrease in function value.
         */
        METHOD2;

    }
}

