/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.bayesian;

import java.util.Arrays;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.distributions.Normal;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.MathTricks;
import jsat.math.OnLineStatistics;

/**
 * An updateable (online) naive Bayes classifier.
 * <p>
 * For every target class the model keeps:
 * <ul>
 * <li>one {@link OnLineStatistics} per numeric attribute; at classification
 * time each numeric value is scored with a {@link Normal} log-density using
 * that statistic's mean and standard deviation,</li>
 * <li>a count for every category of every categorical attribute, and</li>
 * <li>a class prior count.</li>
 * </ul>
 * All categorical and prior counts start at 1 in {@link #setUp}, which acts as
 * add-one (Laplace) smoothing. Sample weights passed to
 * {@link #update(DataPoint, double, int)} are applied to every part of the
 * model.
 */
public class NaiveBayesUpdateable
extends BaseUpdateableClassifier {
    private static final long serialVersionUID = 1835073945715343486L;

    /**
     * Log-likelihood used for a numeric attribute whose normal log-density is
     * NaN or infinite (presumably a zero or undefined standard deviation), so a
     * single attribute cannot poison the whole class score.
     */
    private static final double LOG_PDF_FLOOR = Math.log(1.0E-16);

    /** {@code apriori[class][attribute][category]}: smoothed categorical counts. */
    private double[][][] apriori;
    /** {@code valueStats[class][attribute]}: running statistics of numeric attributes. */
    private OnLineStatistics[][] valueStats;
    /** Running total of all entries of {@link #priors}. */
    private double priorSum = 0.0;
    /** {@code priors[class]}: smoothed class counts. */
    private double[] priors;
    /** Whether numeric vectors are read by iterating their entries rather than by index. */
    private boolean sparseInput = true;

    /** Creates a classifier that treats numeric input as sparse. */
    public NaiveBayesUpdateable() {
        this(true);
    }

    /**
     * Creates a classifier.
     *
     * @param sparse {@code true} to use only the entries produced by iterating
     *               each numeric vector, {@code false} to read every index
     */
    public NaiveBayesUpdateable(boolean sparse) {
        // Assign directly: calling the overridable setter from a constructor
        // would expose a partially constructed object to subclasses.
        this.sparseInput = sparse;
    }

    /**
     * Deep-copy constructor.
     *
     * @param other the classifier to copy; if it has not been set up, only the
     *              sparse flag is copied
     */
    protected NaiveBayesUpdateable(NaiveBayesUpdateable other) {
        this(other.sparseInput);
        if (other.apriori == null) {
            return;
        }
        this.apriori = new double[other.apriori.length][][];
        this.valueStats = new OnLineStatistics[other.valueStats.length][];
        for (int i = 0; i < other.apriori.length; ++i) {
            double[][] srcCounts = other.apriori[i];
            this.apriori[i] = new double[srcCounts.length][];
            for (int j = 0; j < srcCounts.length; ++j) {
                this.apriori[i][j] = Arrays.copyOf(srcCounts[j], srcCounts[j].length);
            }
            OnLineStatistics[] srcStats = other.valueStats[i];
            this.valueStats[i] = new OnLineStatistics[srcStats.length];
            for (int j = 0; j < srcStats.length; ++j) {
                this.valueStats[i][j] = new OnLineStatistics(srcStats[j]);
            }
        }
        this.priorSum = other.priorSum;
        this.priors = Arrays.copyOf(other.priors, other.priors.length);
    }

    @Override
    public NaiveBayesUpdateable clone() {
        return new NaiveBayesUpdateable(this);
    }

    /**
     * Allocates a fresh model. Every categorical count and every class prior
     * starts at 1 (add-one smoothing), and every numeric attribute gets an
     * empty {@link OnLineStatistics}.
     *
     * @param categoricalAttributes description of each categorical attribute
     * @param numericAttributes     number of numeric attributes
     * @param predicting            description of the target variable
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        int nCat = predicting.getNumOfCategories();
        this.apriori = new double[nCat][categoricalAttributes.length][];
        this.valueStats = new OnLineStatistics[nCat][numericAttributes];
        this.priors = new double[nCat];
        Arrays.fill(this.priors, 1.0);
        this.priorSum = nCat;
        for (int i = 0; i < nCat; ++i) {
            for (int j = 0; j < categoricalAttributes.length; ++j) {
                this.apriori[i][j] = new double[categoricalAttributes[j].getNumOfCategories()];
                Arrays.fill(this.apriori[i][j], 1.0);
            }
            for (int j = 0; j < numericAttributes; ++j) {
                this.valueStats[i][j] = new OnLineStatistics();
            }
        }
    }

    /**
     * Incorporates one weighted training example.
     * <p>
     * The weight is applied to the numeric statistics, the categorical counts
     * and the class prior alike. With a weight of 1 this is the plain
     * count-based update.
     *
     * @param dataPoint   the example
     * @param weight      the example's weight
     * @param targetClass index of the example's class
     */
    @Override
    public void update(DataPoint dataPoint, double weight, int targetClass) {
        Vec x = dataPoint.getNumericalValues();
        OnLineStatistics[] stats = this.valueStats[targetClass];
        if (this.sparseInput) {
            for (IndexValue iv : x) {
                stats[iv.getIndex()].add(iv.getValue(), weight);
            }
        } else {
            for (int j = 0; j < x.length(); ++j) {
                stats[j].add(x.get(j), weight);
            }
        }
        int[] catValues = dataPoint.getCategoricalValues();
        double[][] catCounts = this.apriori[targetClass];
        for (int j = 0; j < catCounts.length; ++j) {
            catCounts[j][catValues[j]] += weight;
        }
        // priors and priorSum must move together: priorSum is their total.
        this.priors[targetClass] += weight;
        this.priorSum += weight;
    }

    /**
     * Computes the posterior probability of every class for {@code data}.
     *
     * @param data the point to classify
     * @return the normalized class probabilities
     * @throws UntrainedModelException if the model has not been set up
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.apriori == null) {
            throw new UntrainedModelException("Model has not been intialized");
        }
        int nCat = this.apriori.length;
        CategoricalResults results = new CategoricalResults(nCat);
        double[] logProbs = new double[nCat];
        double maxLogProb = Double.NEGATIVE_INFINITY;
        Vec numVals = data.getNumericalValues();
        for (int i = 0; i < nCat; ++i) {
            double logProb = numericLogLikelihood(i, numVals);
            logProb += categoricalLogLikelihood(i, data);
            logProb += Math.log(this.priors[i] / this.priorSum);
            logProbs[i] = logProb;
            maxLogProb = Math.max(maxLogProb, logProb);
        }
        // Normalize in log space to avoid underflow of the unnormalized terms.
        double denom = MathTricks.logSumExp(logProbs, maxLogProb);
        for (int i = 0; i < results.size(); ++i) {
            results.setProb(i, Math.exp(logProbs[i] - denom));
        }
        results.normalize();
        return results;
    }

    /** Sums the floored normal log-densities of the numeric attributes under one class. */
    private double numericLogLikelihood(int classIdx, Vec numVals) {
        OnLineStatistics[] stats = this.valueStats[classIdx];
        double logProb = 0.0;
        if (this.sparseInput) {
            for (IndexValue iv : numVals) {
                logProb += flooredLogPdf(iv.getValue(), stats[iv.getIndex()]);
            }
        } else {
            for (int j = 0; j < stats.length; ++j) {
                logProb += flooredLogPdf(numVals.get(j), stats[j]);
            }
        }
        return logProb;
    }

    /** Normal log-density of {@code value}, replaced by {@link #LOG_PDF_FLOOR} when NaN or infinite. */
    private static double flooredLogPdf(double value, OnLineStatistics stats) {
        double logPdf = Normal.logPdf(value, stats.getMean(), stats.getStandardDeviation());
        if (Double.isNaN(logPdf) || Double.isInfinite(logPdf)) {
            return LOG_PDF_FLOOR;
        }
        return logPdf;
    }

    /** Sums log(count / total) over the categorical attributes under one class. */
    private double categoricalLogLikelihood(int classIdx, DataPoint data) {
        double[][] counts = this.apriori[classIdx];
        double logProb = 0.0;
        for (int j = 0; j < counts.length; ++j) {
            double total = 0.0;
            for (double c : counts[j]) {
                total += c;
            }
            logProb += Math.log(counts[j][data.getCategoricalValue(j)] / total);
        }
        return logProb;
    }

    /** @return {@code true}; sample weights are honored by {@link #update}. */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /** @return whether numeric vectors are read by iterating their entries */
    public boolean isSparseInput() {
        return this.sparseInput;
    }

    /**
     * Sets whether numeric vectors are read by iterating their entries
     * ({@code true}) or by reading every index ({@code false}).
     *
     * @param sparseInput the new mode
     */
    public void setSparse(boolean sparseInput) {
        this.sparseInput = sparseInput;
    }
}

