/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers;

import java.util.Collections;
import java.util.List;
import java.util.Objects;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;

/**
 * Adapts a {@link Regressor} so it can be used as a binary classifier.
 * <p>
 * Training converts the binary labels into regression targets: category
 * {@code 0} becomes {@code -1.0} and category {@code 1} becomes {@code +1.0}.
 * At prediction time the regressor's output is the decision score. A strictly
 * positive score is assigned to category {@code 1}; any other score (including
 * zero) is assigned to category {@code 0}.
 * <p>
 * If the wrapped regressor implements {@link Parameterized}, its parameters are
 * exposed through this object; otherwise no parameters are reported.
 */
public class RegressorToClassifier
implements BinaryScoreClassifier,
Parameterized {
    private static final long serialVersionUID = -2607433019826385335L;

    /** The wrapped regression model whose output is used as the decision score. */
    private final Regressor regressor;

    /**
     * Creates a classifier backed by the given regressor.
     *
     * @param regressor the regression model to wrap. It is used and trained
     *                  directly, not copied.
     * @throws NullPointerException if {@code regressor} is {@code null}
     */
    public RegressorToClassifier(Regressor regressor) {
        this.regressor = Objects.requireNonNull(regressor, "regressor");
    }

    /**
     * Returns the wrapped regressor's raw prediction for the given point.
     *
     * @param dp the point to score
     * @return the regression output; values above zero indicate category 1
     */
    @Override
    public double getScore(DataPoint dp) {
        return regressor.regress(dp);
    }

    /**
     * Returns a new instance wrapping a clone of the current regressor.
     *
     * @return an independent copy of this classifier
     */
    @Override
    public RegressorToClassifier clone() {
        return new RegressorToClassifier(regressor.clone());
    }

    /**
     * Produces a hard (one-hot) two-class prediction from the decision score.
     *
     * @param data the point to classify
     * @return results with probability 1.0 on category 1 when the score is
     *         strictly positive, otherwise on category 0
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        // A score of exactly zero (or NaN) falls through to category 0.
        int predicted = getScore(data) > 0.0 ? 1 : 0;
        cr.setProb(predicted, 1.0);
        return cr;
    }

    /**
     * Trains the wrapped regressor on {@code -1/+1} targets derived from the
     * binary labels of {@code dataSet}.
     *
     * @param dataSet  binary classification data with categories 0 and 1 only
     * @param parallel forwarded unchanged to the regressor's {@code train}
     * @throws IllegalArgumentException if any data point has a category other
     *                                  than 0 or 1
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        regressor.train(getRegressionDataSet(dataSet), parallel);
    }

    /**
     * Trains the wrapped regressor on {@code -1/+1} targets derived from the
     * binary labels of {@code dataSet}.
     *
     * @param dataSet binary classification data with categories 0 and 1 only
     * @throws IllegalArgumentException if any data point has a category other
     *                                  than 0 or 1
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        regressor.train(getRegressionDataSet(dataSet));
    }

    /**
     * Reports weight support exactly as the wrapped regressor does.
     *
     * @return the wrapped regressor's {@code supportsWeightedData()} result
     */
    @Override
    public boolean supportsWeightedData() {
        return regressor.supportsWeightedData();
    }

    /**
     * Builds a regression view of a binary classification data set.
     * <p>
     * Each point keeps its features, and its label is mapped {0, 1} to
     * {-1, +1}, so the learned decision boundary sits at a score of zero.
     *
     * @param dataSet the binary classification data
     * @return a regression data set with the same points and mapped targets
     * @throws IllegalArgumentException if a category other than 0 or 1 is found
     */
    private RegressionDataSet getRegressionDataSet(ClassificationDataSet dataSet) {
        RegressionDataSet rds = new RegressionDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories());
        for (int i = 0; i < dataSet.size(); i++) {
            int category = dataSet.getDataPointCategory(i);
            // Any other label would become a meaningless target (e.g. 2 -> 3),
            // and classify() can only ever predict 0 or 1.
            if (category != 0 && category != 1) {
                throw new IllegalArgumentException(
                        "RegressorToClassifier supports only binary problems, but data point "
                        + i + " has category " + category);
            }
            rds.addDataPoint(dataSet.getDataPoint(i), 2.0 * category - 1.0);
        }
        return rds;
    }

    /**
     * Returns the wrapped regressor's parameters, if it exposes any.
     *
     * @return the regressor's parameters, or an empty list if it does not
     *         implement {@link Parameterized}
     */
    @Override
    public List<Parameter> getParameters() {
        if (regressor instanceof Parameterized) {
            return ((Parameterized) regressor).getParameters();
        }
        return Collections.emptyList();
    }

    /**
     * Looks up a parameter of the wrapped regressor by name.
     *
     * @param paramName the parameter name
     * @return the matching parameter from the regressor, or {@code null} if the
     *         regressor does not implement {@link Parameterized}
     */
    @Override
    public Parameter getParameter(String paramName) {
        if (regressor instanceof Parameterized) {
            return ((Parameterized) regressor).getParameter(paramName);
        }
        return null;
    }
}

