/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.neuralnetwork;

import jsat.SingleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;

/**
 * A linear perceptron for binary classification problems.
 * <p>
 * The model is a single weight vector over the numeric features plus a bias
 * term. A point is assigned class 1 when {@code w . x + b >= 0} and class 0
 * otherwise. Categorical features are not used.
 * <p>
 * Training is online. A correctly classified point leaves the model unchanged.
 * A misclassified point moves the weights and bias toward the correct side of
 * the decision boundary by {@code learningRate * sampleWeight}.
 */
public class Perceptron extends BaseUpdateableClassifier
        implements BinaryScoreClassifier, SingleWeightVectorModel {
    private static final long serialVersionUID = -3605237847981632020L;
    /** Step size applied on each corrective update; validated to lie in (0, 1]. */
    private double learningRate;
    /** Bias (intercept) term added to the dot product. */
    private double bias;
    /** Weights over the numeric features; {@code null} until {@link #setUp} is called. */
    private Vec weights;

    /**
     * Creates a perceptron with a learning rate of 0.1 and 20 training epochs.
     */
    public Perceptron() {
        this(0.1, 20);
    }

    /**
     * Creates a perceptron.
     *
     * @param learningRate   the step size applied on each misclassification; must be in (0, 1]
     * @param iterationLimit the number of training epochs, forwarded to {@code setEpochs}
     *                       (which may apply its own validation)
     * @throws IllegalArgumentException if {@code learningRate} is outside (0, 1] or is NaN
     */
    public Perceptron(double learningRate, int iterationLimit) {
        // Negated range test so that NaN (for which every comparison is false) is rejected too.
        if (!(learningRate > 0.0 && learningRate <= 1.0)) {
            throw new IllegalArgumentException(
                    "Perceptron learning rate must be in the range (0,1], not " + learningRate);
        }
        this.learningRate = learningRate;
        // Previously this passed this.epochs, silently discarding the caller's value.
        this.setEpochs(iterationLimit);
    }

    /**
     * Classifies a data point, placing all probability mass on the predicted class.
     *
     * @param data the point to classify
     * @return a two-class result with probability 1.0 on the predicted class (0 or 1)
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        cr.setProb(this.output(data), 1.0);
        return cr;
    }

    /**
     * Returns the raw decision value {@code w . x + b} for the point's numeric features.
     * A non-negative score corresponds to class 1.
     *
     * @param dp the point to score
     * @return the signed distance-like score of {@code dp}
     */
    @Override
    public double getScore(DataPoint dp) {
        return this.weights.dot(dp.getNumericalValues()) + this.bias;
    }

    /**
     * Prepares the model for training by allocating a new weight vector sized to
     * the numeric features and resetting the bias to 0.
     *
     * @param categoricalAttributes categorical feature descriptions (unused by this model)
     * @param numericAttributes     the number of numeric features
     * @param predicting            the target variable; must have exactly two categories
     * @throws FailedToFitException if the target is not binary
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("Perceptron is for binary problems only");
        }
        this.weights = new DenseVector(numericAttributes);
        this.bias = 0.0;
    }

    /**
     * Applies the perceptron learning rule for a single observation.
     * The model changes only if the point is currently misclassified.
     *
     * @param dataPoint   the observation
     * @param weight      the importance weight of the observation; scales the step size
     * @param targetClass the true class, 0 or 1
     */
    @Override
    public void update(DataPoint dataPoint, double weight, int targetClass) {
        if (this.output(dataPoint) == targetClass) {
            return; // correctly classified: the perceptron rule makes no change
        }
        // Map class {0, 1} to direction {-1, +1}. Scale by the step size and by the
        // sample weight, since supportsWeightedData() advertises weighted training.
        double c = (targetClass * 2 - 1) * this.learningRate * weight;
        this.weights.mutableAdd(c, dataPoint.getNumericalValues());
        this.bias += c;
    }

    /**
     * Predicts the class index for a point.
     *
     * @param input the point to classify
     * @return 1 if the score is non-negative, otherwise 0
     */
    private int output(DataPoint input) {
        return this.getScore(input) >= 0.0 ? 1 : 0;
    }

    /** @return {@code true}; {@link #update} scales each step by the sample weight */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /** @return the live (not copied) weight vector, or {@code null} before {@link #setUp} */
    @Override
    public Vec getRawWeight() {
        return this.weights;
    }

    /** @return the bias term */
    @Override
    public double getBias() {
        return this.bias;
    }

    /**
     * Returns the weight vector for the given index. This model has only one.
     *
     * @param index must be 0
     * @return the live weight vector
     * @throws IndexOutOfBoundsException if {@code index >= 1}
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index < 1) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the bias for the given index. This model has only one.
     *
     * @param index must be 0
     * @return the bias term
     * @throws IndexOutOfBoundsException if {@code index >= 1}
     */
    @Override
    public double getBias(int index) {
        if (index < 1) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /** @return 1, the number of weight vectors in this model */
    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Creates a copy with the same learning rate and epoch count and a deep copy
     * of the trained weight vector (if any) and bias.
     *
     * @return an independent copy of this perceptron
     */
    @Override
    public Perceptron clone() {
        Perceptron copy = new Perceptron(this.learningRate, this.epochs);
        if (this.weights != null) {
            copy.weights = this.weights.clone();
        }
        copy.bias = this.bias;
        return copy;
    }
}

