/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.svm.extended;

import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.svm.extended.OnlineAMM;
import jsat.linear.Vec;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.random.RandomUtil;

/**
 * Multi-pass (batch) training variant of {@link OnlineAMM}.
 * <p>
 * Training first makes one online pass over the shuffled data to obtain an initial
 * weight-vector assignment for every data point. It then repeats an outer epoch of:
 * <ol>
 * <li>{@link #getSubEpochs() subEpochs} shuffled online passes in which each point's
 * assignment is passed to {@code update} (held fixed), and</li>
 * <li>a re-assignment step that moves every point to the best-scoring weight vector of its
 * own class.</li>
 * </ol>
 * The outer loop stops when no assignment changes, or once {@code getEpochs()} outer
 * epochs have run (10 by default, set in the constructor).
 */
public class AMM
extends OnlineAMM {
    private static final long serialVersionUID = -9198419566231617395L;
    /** Number of online passes made between successive re-assignment steps; always >= 1. */
    private int subEpochs = 1;

    /**
     * Creates a model with {@code lambda = 0.01} and a class budget of 50.
     */
    public AMM() {
        this(0.01);
    }

    /**
     * Creates a model with the given lambda and a class budget of 50.
     *
     * @param lambda value forwarded to {@link OnlineAMM} (presumably the regularization
     *               strength; confirm against OnlineAMM)
     */
    public AMM(double lambda) {
        this(lambda, 50);
    }

    /**
     * Creates a model and sets the number of outer training epochs to 10.
     *
     * @param lambda      value forwarded to {@link OnlineAMM}
     * @param classBudget value forwarded to {@link OnlineAMM} (presumably the maximum number
     *                    of weight vectors per class; confirm against OnlineAMM)
     */
    public AMM(double lambda, int classBudget) {
        super(lambda, classBudget);
        this.setEpochs(10);
    }

    /**
     * Copy constructor. Parent state is copied by {@link OnlineAMM}'s copy constructor;
     * this class additionally copies {@code subEpochs}.
     *
     * @param toCopy the model to copy
     */
    public AMM(AMM toCopy) {
        super(toCopy);
        this.subEpochs = toCopy.subEpochs;
    }

    /**
     * Sets how many full online passes over the data are made before each
     * re-assignment step.
     *
     * @param subEpochs number of passes; must be at least 1
     * @throws IllegalArgumentException if {@code subEpochs < 1}
     */
    public void setSubEpochs(int subEpochs) {
        if (subEpochs < 1) {
            throw new IllegalArgumentException("subEpochs must be positive, not " + subEpochs);
        }
        this.subEpochs = subEpochs;
    }

    /**
     * Returns the number of online passes made between re-assignment steps.
     *
     * @return the number of sub-epochs (at least 1)
     */
    public int getSubEpochs() {
        return this.subEpochs;
    }

    /**
     * Trains the model. The {@code parallel} flag is ignored; training always runs
     * sequentially via {@link #train(ClassificationDataSet)}.
     *
     * @param dataSet  the training data
     * @param parallel ignored
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Trains the model from scratch on {@code dataSet}.
     * <p>
     * The model is reset with {@code setUp}. One shuffled online pass then computes each
     * point's initial assignment, with {@code Integer.MIN_VALUE} passed to {@code update}.
     * That value is presumably a sentinel telling {@code update} to choose the assignment
     * itself; confirm in OnlineAMM. After that, outer epochs alternate between
     * {@code subEpochs} shuffled passes with fixed assignments and a re-assignment step.
     * This continues until no assignment changes or {@code getEpochs()} outer epochs
     * have run.
     *
     * @param dataSet the training data
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        final int n = dataSet.size();
        IntList order = new IntList(n);
        ListUtils.addRange(order, 0, n, 1);
        Random rand = RandomUtil.getRandom();
        // assignment[i] is the key of the weight vector (within its class) that data point i
        // is currently assigned to.
        int[] assignment = new int[n];
        this.setUp(dataSet.getCategories(), dataSet.getNumNumericalVars(), dataSet.getPredicting());

        // Initial pass: let update() pick each point's assignment.
        Collections.shuffle(order, rand);
        for (int i : order) {
            assignment[i] = this.update(dataSet.getDataPoint(i), dataSet.getDataPointCategory(i), Integer.MIN_VALUE);
        }

        // Reset the time counter since training effectively restarts from here.
        this.time = 1;
        int outerEpoch = 0;
        int changed;
        do {
            for (int subEpoch = 0; subEpoch < this.subEpochs; ++subEpoch) {
                Collections.shuffle(order, rand);
                for (int i : order) {
                    assignment[i] = this.update(dataSet.getDataPoint(i), dataSet.getDataPointCategory(i), assignment[i]);
                }
            }
            changed = this.reassign(dataSet, assignment);
        } while (changed != 0 && ++outerEpoch < this.getEpochs());
    }

    /**
     * Re-assigns every data point to the best-scoring weight vector of its own class,
     * visiting points in index order.
     *
     * @param dataSet    the training data
     * @param assignment current per-point assignments; updated in place
     * @return the number of points whose assignment changed
     */
    private int reassign(ClassificationDataSet dataSet, int[] assignment) {
        int changed = 0;
        for (int i = 0; i < assignment.length; ++i) {
            Vec x = dataSet.getDataPoint(i).getNumericalValues();
            Map<Integer, Vec> classWeights = this.weightMatrix.get(dataSet.getDataPointCategory(i));
            int best = bestWeightVector(x, classWeights);
            if (assignment[i] != best) {
                assignment[i] = best;
                ++changed;
            }
        }
        return changed;
    }

    /**
     * Finds the key of the weight vector with the largest dot product with {@code x}.
     * Only scores {@code >= 0} are eligible. Ties go to the vector visited last in the map's
     * iteration order, and NaN scores are never selected.
     *
     * @param x            the point's numeric feature vector
     * @param classWeights weight vectors of the point's class, keyed by vector index
     * @return the winning key, or -1 if the map is empty or no score is {@code >= 0}
     */
    private static int bestWeightVector(Vec x, Map<Integer, Vec> classWeights) {
        // The running maximum deliberately starts at 0.0 rather than -infinity, so a point
        // whose scores are all negative gets -1. Presumably update() treats -1 as "add a new
        // weight vector"; confirm in OnlineAMM.
        double bestScore = 0.0;
        int best = -1;
        for (Map.Entry<Integer, Vec> entry : classWeights.entrySet()) {
            double score = x.dot(entry.getValue());
            if (score >= bestScore) {
                best = entry.getKey();
                bestScore = score;
            }
        }
        return best;
    }

    @Override
    public AMM clone() {
        return new AMM(this);
    }
}

