/*
 * Decompiled with CFR 0.152.
 */
package hivemall.classifier.multiclass;

import hivemall.classifier.multiclass.MulticlassOnlineClassifierUDTF;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionResult;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;

/**
 * Hive table-generating function registered as {@code train_multiclass_perceptron}.
 *
 * <p>The class only supplies the argument-count check and the per-example training
 * step; classification, weight updates and the output schema are inherited from
 * {@link MulticlassOnlineClassifierUDTF}.
 */
@Description(name="train_multiclass_perceptron", value="_FUNC_(list<string|int|bigint> features, {int|string} label [, const string options]) - Returns a relation consists of <{int|string} label, {string|int|bigint} feature, float weight>", extended="Build a prediction model by Perceptron multiclass classifier")
public final class MulticlassPerceptronUDTF
extends MulticlassOnlineClassifierUDTF {
    /**
     * Checks that the function was called with two or three arguments and then
     * delegates the rest of the initialization to the superclass.
     *
     * @param argOIs object inspectors for the call arguments: features, label and,
     *               optionally, an options string
     * @return whatever output inspector the superclass builds
     * @throws UDFArgumentException if the number of arguments is neither 2 nor 3
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        final int numArgs = argOIs.length;
        if (numArgs != 2 && numArgs != 3) {
            // The accepted feature element types mirror the @Description text (string|int|bigint).
            throw new UDFArgumentException("MulticlassPerceptronUDTF takes 2 or 3 arguments: List<Text|Int|BigInt> features, {Int|Text} label [, constant text options]");
        }
        return super.initialize(argOIs);
    }

    /**
     * Processes one training example: classifies it with the current model and,
     * only when the predicted label differs from the actual one, requests an update
     * with a fixed coefficient of {@code 1.0f}.
     *
     * @param features     feature vector of the training example
     * @param actual_label true label of the training example
     */
    @Override
    protected void train(@Nonnull FeatureValue[] features, @Nonnull Object actual_label) {
        final PredictionResult predicted = this.classify(features);
        final Object predicted_label = predicted.getLabel();
        // equals() is invoked on actual_label (declared @Nonnull), so the comparison is
        // safe even if predicted_label is null.
        // NOTE(review): whether classify() can yield a null label is not visible here -
        // confirm against the superclass.
        if (actual_label.equals(predicted_label)) {
            // Correct prediction: leave the model untouched.
            return;
        }
        // NOTE(review): presumably update() shifts weights toward actual_label and away
        // from predicted_label - verify in MulticlassOnlineClassifierUDTF.
        this.update(features, 1.0f, actual_label, predicted_label);
    }
}

