/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform;

import java.util.ArrayList;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.datatransform.InvertibleTransform;
import jsat.datatransform.WhitenedPCA;
import jsat.datatransform.ZeroMeanTransform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.MatrixOfVecs;
import jsat.linear.SingularValueDecomposition;
import jsat.linear.Vec;

/**
 * FastICA estimates an unmixing matrix for Independent Component Analysis using
 * the one-unit fixed-point iteration with deflation: each new component is
 * Gram-Schmidt orthogonalized against the components already found.
 * <p>
 * Unless the data is declared pre-whitened, {@link #fit(DataSet)} first centers
 * the data with a {@link ZeroMeanTransform} and whitens it with a
 * {@link WhitenedPCA}. The whitening is folded into the learned unmixing matrix,
 * and the centering is re-applied in {@link #transform(DataPoint)}.
 */
public class FastICA implements InvertibleTransform {

    private static final long serialVersionUID = -8644025740457515563L;

    /** Maximum number of fixed-point iterations spent on a single component. */
    private static final int MAX_ITERATIONS = 500;

    /** Convergence tolerance on {@code |1 - |w_new . w_old||} between iterations. */
    private static final double TOLERANCE = 1.0E-6;

    /** Number of independent components to extract. */
    private int C;
    /** Contrast function supplying g (deriv1) and g' (deriv2). */
    private NegEntropyFunc G;
    /** If true, {@link #fit(DataSet)} skips centering and whitening. */
    private boolean preWhitened;
    /** Centering learned by the last fit; {@code null} when fit with pre-whitened data or not yet fit. */
    private ZeroMeanTransform zeroMean;
    /** Matrix that a (centered) row vector is multiplied by in {@link #transform(DataPoint)}. */
    private Matrix unmixing;
    /** Pseudo-inverse of {@link #unmixing}, used by {@link #inverse(DataPoint)}. */
    private Matrix mixing;

    /**
     * Creates a FastICA transform that extracts 10 components using
     * {@link DefaultNegEntropyFunc#LOG_COSH} on data that is not pre-whitened.
     */
    public FastICA() {
        this(10);
    }

    /**
     * Creates a FastICA transform using {@link DefaultNegEntropyFunc#LOG_COSH}
     * on data that is not pre-whitened.
     *
     * @param C the number of components to extract (must be positive)
     */
    public FastICA(int C) {
        this(C, DefaultNegEntropyFunc.LOG_COSH, false);
    }

    /**
     * Creates and immediately fits a FastICA transform using
     * {@link DefaultNegEntropyFunc#LOG_COSH} on data that is not pre-whitened.
     *
     * @param data the data to fit
     * @param C    the number of components to extract (must be positive)
     */
    public FastICA(DataSet data, int C) {
        this(data, C, DefaultNegEntropyFunc.LOG_COSH, false);
    }

    /**
     * Creates an unfitted FastICA transform.
     *
     * @param C           the number of components to extract (must be positive)
     * @param G           the contrast function to use (non-null)
     * @param preWhitened whether the data given to {@link #fit(DataSet)} is
     *                    already zero-mean and white
     */
    public FastICA(int C, NegEntropyFunc G, boolean preWhitened) {
        setC(C);
        setNegEntropyFunction(G);
        setPreWhitened(preWhitened);
    }

    /**
     * Creates and immediately fits a FastICA transform.
     *
     * @param data        the data to fit
     * @param C           the number of components to extract (must be positive)
     * @param G           the contrast function to use (non-null)
     * @param preWhitened whether {@code data} is already zero-mean and white
     */
    public FastICA(DataSet data, int C, NegEntropyFunc G, boolean preWhitened) {
        this(C, G, preWhitened);
        fit(data);
    }

    /**
     * Copy constructor. Learned matrices and the centering transform are
     * cloned; the contrast function instance is shared.
     *
     * @param toCopy the object to copy
     */
    public FastICA(FastICA toCopy) {
        this.C = toCopy.C;
        this.G = toCopy.G;
        this.preWhitened = toCopy.preWhitened;
        if (toCopy.zeroMean != null) {
            this.zeroMean = toCopy.zeroMean.clone();
        }
        if (toCopy.unmixing != null) {
            this.unmixing = toCopy.unmixing.clone();
        }
        if (toCopy.mixing != null) {
            this.mixing = toCopy.mixing.clone();
        }
    }

    /**
     * Learns the unmixing and mixing matrices from {@code data}.
     * <p>
     * The object's learned state is replaced only if fitting succeeds, so a
     * failed call leaves any previous fit intact.
     * <p>
     * NOTE(review): deflation can keep at most {@code X.cols()} mutually
     * orthogonal directions. If {@code C} exceeds the column count of the
     * (whitened) data, the extra components are not meaningful. This is not
     * checked here.
     *
     * @param data the data to fit; it is not modified (a shallow clone is
     *             transformed when centering/whitening is needed)
     * @throws FailedToFitException if {@code data} is empty or the contrast
     *                              function produces NaN or infinite values
     */
    @Override
    public void fit(DataSet data) {
        final int N = data.size();
        if (N == 0) {
            throw new FailedToFitException("Can not fit FastICA to an empty data set");
        }

        ZeroMeanTransform center = null;
        WhitenedPCA whiten = null;
        if (!preWhitened) {
            center = new ZeroMeanTransform(data);
            data = data.shallowClone();
            data.applyTransform(center);
            whiten = new WhitenedPCA(data);
            data.applyTransform(whiten);
        }
        final Matrix X = data.getDataMatrixView();

        // Scratch buffers reused across every component and iteration.
        final DenseVector proj = new DenseVector(N);
        final DenseVector wOld = new DenseVector(X.cols());

        final ArrayList<Vec> ws = new ArrayList<Vec>(C);
        for (int p = 0; p < C; p++) {
            ws.add(estimateComponent(X, ws, proj, wOld));
        }

        final Matrix newUnmixing;
        if (whiten != null) {
            // Fold the whitening projection into the unmixing matrix.
            newUnmixing = new MatrixOfVecs(ws).multiply(whiten.transform).transpose();
        } else {
            newUnmixing = new DenseMatrix(new MatrixOfVecs(ws)).transpose();
        }
        final Matrix newMixing = new SingularValueDecomposition(newUnmixing.clone()).getPseudoInverse();

        // Commit only after every step has succeeded. Clearing zeroMean when
        // preWhitened prevents a centering learned by an earlier fit from being
        // applied to the new model.
        this.zeroMean = center;
        this.unmixing = newUnmixing;
        this.mixing = newMixing;
    }

    /**
     * Runs the one-unit FastICA fixed-point iteration for a single component,
     * keeping it orthogonal to the components already found.
     *
     * @param X     the (centered and whitened, unless pre-whitened) data matrix,
     *              one row per data point
     * @param found unit vectors already extracted; not modified
     * @param proj  scratch vector with one entry per row of {@code X}
     * @param wOld  scratch vector of length {@code X.cols()}
     * @return the new unit-length component direction
     * @throws FailedToFitException if g or g' evaluates to NaN or infinity
     */
    private Vec estimateComponent(Matrix X, ArrayList<Vec> found, DenseVector proj, DenseVector wOld) {
        final int N = proj.length();
        final Vec w = Vec.random(X.cols());
        w.normalize();
        int iter = 0;
        do {
            w.copyTo(wOld);

            // proj <- X w, then overwritten element-wise by g(proj) while summing g'(proj)
            proj.zeroOut();
            X.multiply(w, 1.0, proj);
            double gpSum = 0.0;
            for (int i = 0; i < proj.length(); i++) {
                final double x = proj.get(i);
                final double g = G.deriv1(x);
                final double gp = G.deriv2(x, g);
                if (Double.isNaN(g) || Double.isInfinite(g) || Double.isNaN(gp) || Double.isInfinite(gp)) {
                    throw new FailedToFitException("Encountered NaN or Inf in calculation");
                }
                proj.set(i, g);
                gpSum += gp;
            }

            // Fixed-point update: w <- mean(x * g(w.x)) - mean(g'(w.x)) * w
            w.mutableMultiply(-(gpSum / N));
            X.transposeMultiply(1.0 / N, proj, w);

            deflate(w, found);
            w.normalize();
        } while (Math.abs(1.0 - Math.abs(w.dot(wOld))) > TOLERANCE && iter++ < MAX_ITERATIONS);
        return w;
    }

    /**
     * Removes from {@code w} its projection onto each vector of {@code basis}
     * (classical Gram-Schmidt: all coefficients are taken against the incoming
     * {@code w} before any subtraction).
     *
     * @param w     the vector to orthogonalize, modified in place
     * @param basis the vectors to orthogonalize against; not modified
     */
    private static void deflate(Vec w, ArrayList<Vec> basis) {
        final double[] coefs = new double[basis.size()];
        for (int i = 0; i < coefs.length; i++) {
            coefs[i] = w.dot(basis.get(i));
        }
        for (int i = 0; i < coefs.length; i++) {
            w.mutableAdd(-coefs[i], basis.get(i));
        }
    }

    /**
     * Sets the number of components to extract.
     *
     * @param C the number of components
     * @throws IllegalArgumentException if {@code C < 1}
     */
    public void setC(int C) {
        if (C < 1) {
            throw new IllegalArgumentException("Number of components must be positive, not " + C);
        }
        this.C = C;
    }

    /**
     * @return the number of components to extract
     */
    public int getC() {
        return this.C;
    }

    /**
     * Sets the contrast function whose derivatives drive the fixed-point update.
     *
     * @param G the contrast function
     * @throws NullPointerException if {@code G} is null
     */
    public void setNegEntropyFunction(NegEntropyFunc G) {
        if (G == null) {
            throw new NullPointerException("Negative Entropy function must be non-null");
        }
        this.G = G;
    }

    /**
     * @return the contrast function in use
     */
    public NegEntropyFunc getNegEntropyFunction() {
        return this.G;
    }

    /**
     * @param preWhitened {@code true} if data passed to {@link #fit(DataSet)} is
     *                    already zero-mean and white, so those steps are skipped
     */
    public void setPreWhitened(boolean preWhitened) {
        this.preWhitened = preWhitened;
    }

    /**
     * @return whether data passed to {@link #fit(DataSet)} is assumed pre-whitened
     */
    public boolean isPreWhitened() {
        return this.preWhitened;
    }

    /**
     * Centers {@code dp} (when a centering was learned) and projects it onto the
     * independent components. Categorical features are passed through.
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        final Vec x = zeroMean != null ? zeroMean.transform(dp).getNumericalValues() : dp.getNumericalValues();
        final Vec components = x.multiply(unmixing);
        return new DataPoint(components, dp.getCategoricalValues(), dp.getCategoricalData());
    }

    /**
     * Maps component values back to the input space via the pseudo-inverse of
     * the unmixing matrix, then undoes the centering when one was learned.
     * Categorical features are passed through.
     */
    @Override
    public DataPoint inverse(DataPoint dp) {
        final Vec x = dp.getNumericalValues().multiply(mixing);
        final DataPoint toRet = new DataPoint(x, dp.getCategoricalValues(), dp.getCategoricalData());
        if (zeroMean != null) {
            zeroMean.mutableInverse(toRet);
        }
        return toRet;
    }

    @Override
    public FastICA clone() {
        return new FastICA(this);
    }

    /**
     * Built-in contrast functions. Each constant supplies g (the first
     * derivative of its contrast) and g'.
     */
    public static enum DefaultNegEntropyFunc implements NegEntropyFunc {
        /** g(x) = tanh(x), g'(x) = 1 - tanh(x)^2. */
        LOG_COSH {
            @Override
            public double deriv1(double x) {
                return Math.tanh(x);
            }

            @Override
            public double deriv2(double x, double d1) {
                return 1.0 - d1 * d1;
            }
        },
        /** g(x) = x exp(-x^2/2), g'(x) = (1 - x^2) exp(-x^2/2). */
        EXP {
            @Override
            public double deriv1(double x) {
                return x * Math.exp(-x * x / 2.0);
            }

            @Override
            public double deriv2(double x, double d1) {
                if (x == 0.0) {
                    return 1.0;
                }
                // d1 / x recovers exp(-x^2/2) without a second call to Math.exp
                return (1.0 - x * x) * (d1 / x);
            }
        },
        /** g(x) = x^3, g'(x) = 3 x^2. */
        KURTOSIS {
            @Override
            public double deriv1(double x) {
                return x * x * x;
            }

            @Override
            public double deriv2(double x, double d1) {
                return x * x * 3.0;
            }
        };

        @Override
        public abstract double deriv1(double x);

        @Override
        public abstract double deriv2(double x, double d1);
    }

    /**
     * Contrast function used by FastICA's fixed-point update.
     */
    public static interface NegEntropyFunc {
        /**
         * @param x the projected value w.x
         * @return g(x), the first derivative of the contrast function
         */
        public double deriv1(double x);

        /**
         * @param x  the projected value w.x
         * @param d1 the value previously returned by {@link #deriv1(double)} for
         *           the same {@code x}, supplied so it need not be recomputed
         * @return g'(x), the second derivative of the contrast function
         */
        public double deriv2(double x, double d1);
    }
}

