/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransformBase;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Matrix;
import jsat.linear.RandomMatrix;
import jsat.linear.Vec;
import jsat.utils.IntList;
import jsat.utils.random.RandomUtil;

/**
 * Random linear projection of a data point's numeric features into {@code k}
 * dimensions (Johnson-Lindenstrauss style). Categorical features are passed through
 * unchanged.
 * <p>
 * In the {@link TransformMode#GAUSS} and {@link TransformMode#BINARY} modes the
 * numeric vector is multiplied by a random {@code k x d} matrix. In the sparse modes
 * ({@link TransformMode#SPARSE}, {@link TransformMode#SPARSE_SQRT},
 * {@link TransformMode#SPARSE_LOG}) each input feature is instead added, with a random
 * sign, to a small random subset of the output dimensions.
 */
public class JLTransform
extends DataTransformBase {
    private static final long serialVersionUID = -8621368067861343913L;
    /**
     * Sparse modes only: for each input feature, the output dimensions it contributes
     * to. An entry {@code j >= 0} adds {@code +x} to output {@code j}; an entry
     * {@code j < 0} adds {@code -x} to output {@code -j - 1}.
     */
    private List<IntList> sparse_jl_map;
    /** Sparse modes only: scale applied once to the accumulated projected vector. */
    private double sparse_jl_cnst;
    private TransformMode mode;
    /** Projection matrix built by {@link #fit}; multiplied by the input in the non-sparse modes. */
    private Matrix R;
    private int k;
    private boolean inMemory;

    /**
     * Copy constructor producing an independent deep copy.
     *
     * @param transform the transform to copy; may be unfitted
     */
    protected JLTransform(JLTransform transform) {
        this.mode = transform.mode;
        this.k = transform.k;
        this.inMemory = transform.inMemory;
        // An unfitted transform has no matrix yet; cloning must not throw in that case.
        this.R = transform.R == null ? null : transform.R.clone();
        if (transform.sparse_jl_map != null) {
            this.sparse_jl_map = new ArrayList<IntList>(transform.sparse_jl_map.size());
            for (IntList a : transform.sparse_jl_map) {
                this.sparse_jl_map.add(new IntList(a));
            }
        }
        this.sparse_jl_cnst = transform.sparse_jl_cnst;
    }

    /** Creates a transform projecting to 50 dimensions using {@link TransformMode#SPARSE_SQRT}. */
    public JLTransform() {
        this(50);
    }

    /**
     * Creates a transform using {@link TransformMode#SPARSE_SQRT}.
     *
     * @param k the number of output dimensions
     */
    public JLTransform(int k) {
        this(k, TransformMode.SPARSE_SQRT);
    }

    /**
     * Creates a transform with {@code inMemory = true}.
     *
     * @param k    the number of output dimensions
     * @param mode how the random projection is generated
     */
    public JLTransform(int k, TransformMode mode) {
        this(k, mode, true);
    }

    /**
     * @param k        the number of output dimensions
     * @param mode     how the random projection is generated
     * @param inMemory for {@link TransformMode#GAUSS}, whether {@link #fit} copies the
     *                 random matrix into a {@link DenseMatrix}
     */
    public JLTransform(int k, TransformMode mode, boolean inMemory) {
        this.mode = mode;
        this.k = k;
        this.inMemory = inMemory;
    }

    /**
     * Builds a new random projection sized for the numeric features of {@code data}.
     * Only the number of numeric features is read from the data set.
     *
     * @param data the data set whose dimensionality the projection is sized for
     * @throws IllegalStateException if the projected dimension is less than 1
     */
    @Override
    public void fit(DataSet data) {
        if (this.k < 1) {
            throw new IllegalStateException("Projected dimension must be positive, was " + this.k);
        }
        final int d = data.getNumNumericalVars();
        final Random rand = RandomUtil.getRandom();

        final RandomMatrixJL implicitR = new RandomMatrixJL(this.k, d, rand.nextLong(), this.mode);
        this.R = implicitR;

        switch (this.mode) {
            case GAUSS:
                this.sparse_jl_map = null;
                if (this.inMemory) {
                    // Materialize the random matrix into an explicit dense copy.
                    this.R = new DenseMatrix(this.k, d);
                    this.R.mutableAdd(implicitR);
                }
                break;
            case BINARY:
                // transform() multiplies by R in this mode, so no sparse map is needed.
                this.sparse_jl_map = null;
                break;
            default:
                fitSparseMap(d, rand);
        }
    }

    /**
     * Builds {@link #sparse_jl_map} and {@link #sparse_jl_cnst} for the sparse modes.
     * Each input feature is assigned {@code nnz = max(1, k / s)} distinct output
     * dimensions, drawn by shuffling all {@code k} dimensions. The first half of them
     * get a positive sign and the rest a negative sign.
     */
    private void fitSparseMap(int d, Random rand) {
        final int s = sparsityFactor(this.mode, d);
        // k >= 1 and s >= 1, so 1 <= nnz <= k. Never 0, which would zero every output.
        final int nnz = Math.max(1, this.k / s);
        // Each feature x_j contributes nnz * (cnst * x_j)^2 = x_j^2 to the squared output
        // norm, matching the 1/sqrt(k) normalisation used by the dense modes.
        this.sparse_jl_cnst = 1.0 / Math.sqrt(nnz);
        this.sparse_jl_map = new ArrayList<IntList>(d);

        IntList allEmbedDims = IntList.range(0, this.k);
        for (int j = 0; j < d; ++j) {
            Collections.shuffle(allEmbedDims, rand);
            IntList columnMap = new IntList(nnz);
            for (int i = 0; i < nnz / 2; ++i) {
                columnMap.add(allEmbedDims.get(i));
            }
            for (int i = nnz / 2; i < nnz; ++i) {
                columnMap.add(-allEmbedDims.get(i) - 1);
            }
            // Order by decoded output dimension so transform() visits outputs in increasing order.
            Collections.sort(columnMap, (a, b) -> Integer.compare(decodeIndex(a), decodeIndex(b)));
            this.sparse_jl_map.add(columnMap);
        }
    }

    /**
     * Returns the sparsity factor {@code s} (always at least 1) for a sparse mode:
     * {@code round(sqrt(d + 1))} for SPARSE_SQRT, {@code round(d / ln(d + 1))} for
     * SPARSE_LOG, and 3 otherwise.
     */
    private static int sparsityFactor(TransformMode mode, int d) {
        double s;
        switch (mode) {
            case SPARSE_SQRT:
                s = Math.sqrt(d + 1);
                break;
            case SPARSE_LOG:
                // d == 0 gives 0/0 = NaN, which rounds to 0 and is clamped below.
                s = (double) d / Math.log(d + 1);
                break;
            default:
                s = 3;
        }
        return (int) Math.max(1L, Math.round(s));
    }

    /** Decodes a sparse-map entry into the output dimension it refers to. */
    private static int decodeIndex(int encoded) {
        return encoded >= 0 ? encoded : -encoded - 1;
    }

    /**
     * Sets how the projection is generated. Call {@link #fit} again afterwards:
     * {@link #transform} dispatches on the current mode but uses the structures built
     * by the last fit.
     *
     * @param mode the projection mode
     */
    public void setMode(TransformMode mode) {
        this.mode = mode;
    }

    /** @return the projection mode */
    public TransformMode getMode() {
        return this.mode;
    }

    /**
     * @param inMemory for {@link TransformMode#GAUSS}, whether {@link #fit} copies the
     *                 random matrix into a {@link DenseMatrix}
     */
    public void setInMemory(boolean inMemory) {
        this.inMemory = inMemory;
    }

    /** @return whether GAUSS-mode fitting materializes the random matrix */
    public boolean isInMemory() {
        return this.inMemory;
    }

    /**
     * Sets the number of output dimensions. Takes effect at the next {@link #fit}.
     *
     * @param k the number of output dimensions
     */
    public void setProjectedDimension(int k) {
        this.k = k;
    }

    /** @return the number of output dimensions */
    public int getProjectedDimension() {
        return this.k;
    }

    /**
     * Suggests a search distribution for the projected dimension. Returns
     * LogUniform(100, 1000) when the data set has more than 10000 numeric features,
     * and LogUniform(10, 100) otherwise.
     *
     * @param d the data set to examine
     * @return the suggested distribution
     */
    public static Distribution guessProjectedDimension(DataSet d) {
        double max = 100.0;
        double min = 10.0;
        if (d.getNumNumericalVars() > 10000) {
            min = 100.0;
            max = 1000.0;
        }
        return new LogUniform(min, max);
    }

    /**
     * Projects the numeric features of {@code dp}; categorical values are copied as-is.
     *
     * @param dp the data point to transform
     * @return a new data point holding the projected numeric vector
     * @throws IllegalStateException if {@link #fit} has not built the structures the
     *                               current mode needs
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        Vec newVec;
        switch (this.mode) {
            case SPARSE_SQRT:
            case SPARSE_LOG:
            case SPARSE: {
                if (this.sparse_jl_map == null) {
                    throw new IllegalStateException("Transform has not been fit for mode " + this.mode);
                }
                newVec = new DenseVector(this.k);
                for (IndexValue iv : dp.getNumericalValues()) {
                    double x_i = iv.getValue();
                    for (int j : this.sparse_jl_map.get(iv.getIndex())) {
                        if (j >= 0) {
                            newVec.increment(j, x_i);
                        } else {
                            newVec.increment(-j - 1, -x_i);
                        }
                    }
                }
                // Scale once, after every feature has been accumulated.
                newVec.mutableMultiply(this.sparse_jl_cnst);
                break;
            }
            default: {
                if (this.R == null) {
                    throw new IllegalStateException("Transform has not been fit");
                }
                newVec = this.R.multiply(dp.getNumericalValues());
            }
        }
        return new DataPoint(newVec, dp.getCategoricalValues(), dp.getCategoricalData());
    }

    @Override
    public JLTransform clone() {
        return new JLTransform(this);
    }

    /**
     * Random matrix whose entries are produced by {@link #getVal(Random)} according to
     * the mode. Only GAUSS, BINARY and SPARSE define entry values; other modes throw
     * from {@code getVal}.
     */
    private static class RandomMatrixJL
    extends RandomMatrix {
        private static final long serialVersionUID = 2009377824896155918L;
        /** Magnitude of non-zero entries (for GAUSS, the standard deviation). */
        public double cnst;
        private TransformMode mode;

        public RandomMatrixJL(RandomMatrixJL toCopy) {
            super(toCopy);
            this.cnst = toCopy.cnst;
            this.mode = toCopy.mode;
        }

        /**
         * @param rows    number of rows (the projected dimension k)
         * @param cols    number of columns (input dimension)
         * @param XORSeed seed passed to {@link RandomMatrix}
         * @param mode    how entries are generated
         */
        public RandomMatrixJL(int rows, int cols, long XORSeed, TransformMode mode) {
            super(rows, cols, XORSeed);
            this.mode = mode;
            int k = rows;
            if (mode == TransformMode.GAUSS || mode == TransformMode.BINARY) {
                this.cnst = 1.0 / Math.sqrt(k);
            } else if (mode == TransformMode.SPARSE) {
                this.cnst = Math.sqrt(3.0) / Math.sqrt(k);
            }
        }

        @Override
        protected double getVal(Random rand) {
            if (this.mode == TransformMode.GAUSS) {
                return rand.nextGaussian() * this.cnst;
            }
            if (this.mode == TransformMode.BINARY) {
                return rand.nextBoolean() ? -this.cnst : this.cnst;
            }
            if (this.mode == TransformMode.SPARSE) {
                // -cnst w.p. 1/6, +cnst w.p. 1/6, 0 otherwise.
                int val = rand.nextInt(6);
                if (val == 0) {
                    return -this.cnst;
                }
                if (val == 1) {
                    return this.cnst;
                }
                return 0.0;
            }
            throw new RuntimeException("BUG: Please report");
        }

        @Override
        public RandomMatrixJL clone() {
            return new RandomMatrixJL(this);
        }
    }

    /** How the random projection is generated. */
    public static enum TransformMode {
        /** Dense matrix of Gaussian entries with standard deviation 1/sqrt(k). */
        GAUSS,
        /** Dense matrix of entries +-1/sqrt(k), each with probability 1/2. */
        BINARY,
        /** Sparse projection with sparsity factor s = 3. */
        SPARSE,
        /** Sparse projection with s = round(sqrt(d + 1)), d = number of numeric features. */
        SPARSE_SQRT,
        /** Sparse projection with s = round(d / ln(d + 1)), d = number of numeric features. */
        SPARSE_LOG;

    }
}

