/*
 * Decompiled with CFR 0.152.
 */
package hivemall.mf;

import hivemall.common.EtaEstimator;
import hivemall.mf.OnlineMatrixFactorizationUDTF;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;

/**
 * Hive UDTF {@code train_mf_sgd}: the SGD flavour of {@link OnlineMatrixFactorizationUDTF}.
 *
 * <p>This subclass only contributes the learning-rate policy. It registers the learning-rate
 * related command-line switches, builds an {@link EtaEstimator} from the parsed options, and
 * answers {@link #eta()} by asking that estimator for the rate at the current step counter
 * inherited from the superclass.
 */
@Description(
    name = "train_mf_sgd",
    value = "_FUNC_(INT user, INT item, FLOAT rating [, CONSTANT STRING options]) - Returns a relation consists of <int idx, array<float> Pu, array<float> Qi [, float Bu, float Bi [, float mu]]>")
public final class MatrixFactorizationSGDUDTF extends OnlineMatrixFactorizationUDTF {

    /** Learning-rate schedule; assigned in {@link #processOptions(ObjectInspector[])}. */
    private EtaEstimator etaEstimator;

    /**
     * Extends the superclass option set with the learning-rate switches
     * ({@code -eta}, {@code -eta0}, {@code -t/-total_steps}, {@code -power_t}).
     *
     * @return the superclass {@link Options} instance with the extra switches registered
     */
    @Override
    protected Options getOptions() {
        // Options.addOption returns the same Options instance, so the registrations can be chained.
        // NOTE(review): "eta" and "eta0" share the same help wording ("The initial learning
        // rate"); confirm against EtaEstimator which one is the fixed vs. initial rate.
        return super.getOptions()
            .addOption("eta", true, "The initial learning rate [default: 0.001]")
            .addOption("eta0", true, "The initial learning rate [default 0.2]")
            .addOption("t", "total_steps", true, "The total number of training examples")
            .addOption("power_t", true, "The exponent for inverse scaling learning rate [default 0.1]");
    }

    /**
     * Lets the superclass parse the arguments, then derives the learning-rate schedule from the
     * resulting command line.
     *
     * @param argOIs object inspectors of the UDTF arguments
     * @return the command line parsed by the superclass, unchanged
     * @throws UDFArgumentException if the superclass rejects the arguments
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        final CommandLine parsed = super.processOptions(argOIs);
        etaEstimator = EtaEstimator.get(parsed);
        return parsed;
    }

    /**
     * Returns the learning rate for the current step.
     *
     * @return the estimator's rate evaluated at the inherited {@code count}
     *         (presumably the number of training rows seen so far — confirm in the superclass)
     */
    @Override
    protected float eta() {
        final EtaEstimator estimator = etaEstimator;
        return estimator.eta(count);
    }
}

