/*
 * Decompiled with CFR 0.152.
 */
package jsat.parameters;

import java.util.ArrayList;
import java.util.List;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.evaluation.Accuracy;
import jsat.classifiers.evaluation.ClassificationScore;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.Regressor;
import jsat.regression.evaluation.MeanSquaredError;
import jsat.regression.evaluation.RegressionScore;

/**
 * Base class for meta-models that search over the {@link Parameter}s of a
 * {@link Parameterized} base model and keep the resulting trained model.
 * <p>
 * A search may wrap a {@link Classifier}, a {@link Regressor}, or an object
 * that is both. When the base model implements both interfaces, the same
 * instance is stored in {@link #baseClassifier} and {@link #baseRegressor}.
 * Prediction calls ({@link #classify(DataPoint)}, {@link #regress(DataPoint)})
 * are delegated to the trained model that concrete subclasses store in
 * {@link #trainedClassifier} / {@link #trainedRegressor}.
 */
public abstract class ModelSearch
implements Classifier,
Regressor {
    /** Base classifier whose parameters are searched; may be null for a regressor-only search. */
    protected Classifier baseClassifier;
    /** Trained classifier used by {@link #classify(DataPoint)}; null until a subclass sets it. */
    protected Classifier trainedClassifier;
    /** Score used to judge classification candidates; defaults to {@link Accuracy}. */
    protected ClassificationScore classificationTargetScore = new Accuracy();
    /** Score used to judge regression candidates; defaults to {@code new MeanSquaredError(true)}. */
    protected RegressionScore regressionTargetScore = new MeanSquaredError(true);
    /** Base regressor whose parameters are searched; may be null for a classifier-only search. */
    protected Regressor baseRegressor;
    /** Trained regressor used by {@link #regress(DataPoint)}; null until a subclass sets it. */
    protected Regressor trainedRegressor;
    /** Parameters of the base model that take part in the search. */
    protected List<Parameter> searchParams;
    /** Number of folds used by the search (presumably cross-validation folds; consumed by subclasses). */
    protected int folds;
    /** Whether candidate models should be trained in parallel; defaults to {@code true}. */
    protected boolean trainModelsInParallel = true;
    /** Whether a final model should be trained after the search; defaults to {@code true}. */
    protected boolean trainFinalModel = true;
    /** Whether the same CV folds should be reused across candidates; defaults to {@code true}. */
    protected boolean reuseSameCVFolds = true;

    /**
     * Creates a search over the parameters of a regressor.
     *
     * @param baseRegressor the regressor to search over; must implement {@link Parameterized}
     * @param folds the number of folds to use
     * @throws FailedToFitException if {@code baseRegressor} is not {@link Parameterized}
     *         (including when it is null)
     */
    public ModelSearch(Regressor baseRegressor, int folds) {
        if (!(baseRegressor instanceof Parameterized)) {
            throw new FailedToFitException("Given regressor does not support parameterized alterations");
        }
        this.baseRegressor = baseRegressor;
        // A model that is also a classifier can serve both roles with one instance.
        if (baseRegressor instanceof Classifier) {
            this.baseClassifier = (Classifier) baseRegressor;
        }
        this.searchParams = new ArrayList<Parameter>();
        this.folds = folds;
    }

    /**
     * Creates a search over the parameters of a classifier.
     *
     * @param baseClassifier the classifier to search over; must implement {@link Parameterized}
     * @param folds the number of folds to use
     * @throws FailedToFitException if {@code baseClassifier} is not {@link Parameterized}
     *         (including when it is null)
     */
    public ModelSearch(Classifier baseClassifier, int folds) {
        if (!(baseClassifier instanceof Parameterized)) {
            throw new FailedToFitException("Given classifier does not support parameterized alterations");
        }
        this.baseClassifier = baseClassifier;
        // A model that is also a regressor can serve both roles with one instance.
        if (baseClassifier instanceof Regressor) {
            this.baseRegressor = (Regressor) baseClassifier;
        }
        this.searchParams = new ArrayList<Parameter>();
        this.folds = folds;
    }

    /**
     * Copy constructor. Deep-copies the base and trained models and re-binds
     * the search parameters to the copied base model. It also copies all
     * search configuration: fold count, flags and target scores.
     *
     * @param toCopy the search to copy
     */
    public ModelSearch(ModelSearch toCopy) {
        if (toCopy.baseClassifier != null) {
            this.baseClassifier = toCopy.baseClassifier.clone();
            if (this.baseClassifier instanceof Regressor) {
                this.baseRegressor = (Regressor) this.baseClassifier;
            }
        } else {
            this.baseRegressor = toCopy.baseRegressor.clone();
            if (this.baseRegressor instanceof Classifier) {
                this.baseClassifier = (Classifier) this.baseRegressor;
            }
        }
        if (toCopy.trainedClassifier != null) {
            this.trainedClassifier = toCopy.trainedClassifier.clone();
        }
        if (toCopy.trainedRegressor != null) {
            // If the source shares one trained model between both roles, keep that
            // sharing in the copy instead of cloning the same model twice.
            if (toCopy.trainedRegressor == toCopy.trainedClassifier
                    && this.trainedClassifier instanceof Regressor) {
                this.trainedRegressor = (Regressor) this.trainedClassifier;
            } else {
                this.trainedRegressor = toCopy.trainedRegressor.clone();
            }
        }
        // Look parameters up by name on the *copied* base model so that
        // altering them affects this copy, not the original.
        this.searchParams = new ArrayList<Parameter>();
        for (Parameter p : toCopy.searchParams) {
            this.searchParams.add(this.getParameterByName(p.getName()));
        }
        this.folds = toCopy.folds;
        // Previously these were left at their defaults, silently discarding
        // any configuration made on the source object.
        this.trainModelsInParallel = toCopy.trainModelsInParallel;
        this.trainFinalModel = toCopy.trainFinalModel;
        this.reuseSameCVFolds = toCopy.reuseSameCVFolds;
        this.classificationTargetScore = toCopy.classificationTargetScore == null
                ? null : toCopy.classificationTargetScore.clone();
        this.regressionTargetScore = toCopy.regressionTargetScore == null
                ? null : toCopy.regressionTargetScore.clone();
    }

    /**
     * Sets whether candidate models should be trained in parallel.
     *
     * @param trainInParallel {@code true} to train in parallel
     */
    public void setTrainModelsInParallel(boolean trainInParallel) {
        this.trainModelsInParallel = trainInParallel;
    }

    /**
     * @return {@code true} if candidate models are trained in parallel
     */
    public boolean isTrainModelsInParallel() {
        return this.trainModelsInParallel;
    }

    /**
     * Sets whether a final model should be trained once the search is done.
     *
     * @param trainFinalModel {@code true} to train a final model
     */
    public void setTrainFinalModel(boolean trainFinalModel) {
        this.trainFinalModel = trainFinalModel;
    }

    /**
     * @return {@code true} if a final model will be trained
     */
    public boolean isTrainFinalModel() {
        return this.trainFinalModel;
    }

    /**
     * Sets whether the same CV folds should be reused for every candidate.
     *
     * @param reuseSameSplit {@code true} to reuse the same folds
     */
    public void setReuseSameCVFolds(boolean reuseSameSplit) {
        this.reuseSameCVFolds = reuseSameSplit;
    }

    /**
     * @return {@code true} if the same CV folds are reused for every candidate
     */
    public boolean isReuseSameCVFolds() {
        return this.reuseSameCVFolds;
    }

    /**
     * @return the base classifier being searched, or null for a regressor-only search
     */
    public Classifier getBaseClassifier() {
        return this.baseClassifier;
    }

    /**
     * @return the trained classifier, or null if none has been set yet
     */
    public Classifier getTrainedClassifier() {
        return this.trainedClassifier;
    }

    /**
     * @return the base regressor being searched, or null for a classifier-only search
     */
    public Regressor getBaseRegressor() {
        return this.baseRegressor;
    }

    /**
     * @return the trained regressor, or null if none has been set yet
     */
    public Regressor getTrainedRegressor() {
        return this.trainedRegressor;
    }

    /**
     * Sets the score used to compare classification candidates.
     *
     * @param classifierTargetScore the score to use
     */
    public void setClassificationTargetScore(ClassificationScore classifierTargetScore) {
        this.classificationTargetScore = classifierTargetScore;
    }

    /**
     * @return the score used to compare classification candidates
     */
    public ClassificationScore getClassificationTargetScore() {
        return this.classificationTargetScore;
    }

    /**
     * Sets the score used to compare regression candidates.
     *
     * @param regressionTargetScore the score to use
     */
    public void setRegressionTargetScore(RegressionScore regressionTargetScore) {
        this.regressionTargetScore = regressionTargetScore;
    }

    /**
     * @return the score used to compare regression candidates
     */
    public RegressionScore getRegressionTargetScore() {
        return this.regressionTargetScore;
    }

    /**
     * Looks up a parameter of the base model by name. The base classifier is
     * used when present; otherwise the base regressor is used.
     *
     * @param name the parameter name
     * @return the matching parameter, never null
     * @throws IllegalArgumentException if the base model has no such parameter
     */
    protected Parameter getParameterByName(String name) throws IllegalArgumentException {
        Parameterized source = this.baseClassifier != null
                ? (Parameterized) this.baseClassifier
                : (Parameterized) this.baseRegressor;
        Parameter param = source.getParameter(name);
        if (param == null) {
            throw new IllegalArgumentException("Parameter " + name + " does not exist");
        }
        return param;
    }

    /**
     * Classifies {@code data} with the trained classifier.
     *
     * @throws UntrainedModelException if no trained classifier is available
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.trainedClassifier == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        return this.trainedClassifier.classify(data);
    }

    /**
     * Predicts a value for {@code data} with the trained regressor.
     *
     * @throws UntrainedModelException if no trained regressor is available
     */
    @Override
    public double regress(DataPoint data) {
        if (this.trainedRegressor == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        return this.trainedRegressor.regress(data);
    }

    /**
     * Delegates to the base classifier if present, otherwise to the base regressor.
     */
    @Override
    public boolean supportsWeightedData() {
        return this.baseClassifier != null ? this.baseClassifier.supportsWeightedData() : this.baseRegressor.supportsWeightedData();
    }

    @Override
    public abstract ModelSearch clone();
}

