/*
 * Decompiled with CFR 0.152.
 */
package hivemall.fm;

import hivemall.fm.FFMPredictionModel;
import hivemall.fm.FFMStringFeatureMapModel;
import hivemall.fm.FMHyperParameters;
import hivemall.fm.FactorizationMachineUDTF;
import hivemall.fm.Feature;
import hivemall.fm.IntFeature;
import hivemall.utils.collections.DoubleArray3D;
import hivemall.utils.collections.IntArrayList;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.Text3;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.math.MathUtils;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableStringObjectInspector;
import org.apache.hadoop.io.Text;

/**
 * Hive UDTF {@code train_ffm}: trains a field-aware factorization machine (FFM) from
 * {@code (array<string> x, double y [, options])} rows and, when the model is forwarded,
 * emits one {@code (model_id, model)} row whose {@code model} column holds the serialized
 * {@link FFMPredictionModel}.
 *
 * <p>The SGD step relies on state inherited from {@link FactorizationMachineUDTF}
 * ({@code _t}, {@code _etaEstimator}, {@code _lossFunction}, {@code _cvState},
 * {@code _factors}, {@code _model}, ...), which is declared in the parent class and not
 * shown here.
 */
@Description(name="train_ffm", value="_FUNC_(array<string> x, double y [, const string options]) - Returns a prediction model")
public final class FieldAwareFactorizationMachineUDTF
extends FactorizationMachineUDTF {
    private static final Log LOG = LogFactory.getLog(FieldAwareFactorizationMachineUDTF.class);

    // Hyper-parameters copied from FFMHyperParameters in processOptions().
    private boolean _FTRL;
    private boolean _globalBias;
    private boolean _linearCoeff;
    private int _numFeatures;
    private int _numFields;

    // Typed alias of the model created by initModel(); released in forwardModel()/close().
    private transient FFMStringFeatureMapModel _ffmModel;
    // Scratch list holding the field of each feature of the current training instance.
    private transient IntArrayList _fieldList;
    // Reusable buffer passed back into FFMStringFeatureMapModel.sumVfX(); null until the
    // first non-trivial update and again after forwardModel().
    @Nullable
    private transient DoubleArray3D _sumVfX;

    /**
     * Returns the parent's options extended with the FFM-specific ones.
     */
    @Override
    protected Options getOptions() {
        final Options opts = super.getOptions();
        opts.addOption("w0", "global_bias", false, "Whether to include global bias term w0 [default: OFF]");
        opts.addOption("disable_wi", "no_coeff", false, "Not to include linear term [default: OFF]");
        opts.addOption("feature_hashing", true, "The number of bits for feature hashing in range [18,31] [default:21]");
        opts.addOption("num_fields", true, "The number of fields [default:1024]");
        opts.addOption("disable_adagrad", false, "Whether to use AdaGrad for tuning learning rate [default: ON]");
        opts.addOption("eta0_V", true, "The initial learning rate for V [default 1.0]");
        opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1.0]");
        // "Follow-The-Regularized-Leader" (FTRL) was misspelled "-Reader" in the help text.
        opts.addOption("disable_ftrl", false, "Whether not to use Follow-The-Regularized-Leader [default: OFF]");
        opts.addOption("alpha", "alphaFTRL", true, "Alpha value (learning rate) of Follow-The-Regularized-Leader [default 0.1]");
        opts.addOption("beta", "betaFTRL", true, "Beta value (a learning smoothing parameter) of Follow-The-Regularized-Leader [default 1.0]");
        opts.addOption("lambda1", true, "L1 regularization value of Follow-The-Regularized-Leader that controls model sparseness [default 0.1]");
        opts.addOption("lambda2", true, "L2 regularization value of Follow-The-Regularized-Leader [default 0.01]");
        return opts;
    }

    /**
     * Adaptive regularization is not supported by the FFM trainer.
     *
     * @return always {@code false}
     */
    @Override
    protected boolean isAdaptiveRegularizationSupported() {
        return false;
    }

    /**
     * Creates the FFM-specific hyper-parameter holder.
     */
    @Override
    protected FMHyperParameters.FFMHyperParameters newHyperParameters() {
        return new FMHyperParameters.FFMHyperParameters();
    }

    /**
     * Lets the parent parse the options, then caches the FFM hyper-parameters in fields
     * so the per-row training path does not re-read them from {@code _params}.
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        final CommandLine cl = super.processOptions(argOIs);
        final FMHyperParameters.FFMHyperParameters params = (FMHyperParameters.FFMHyperParameters) this._params;
        this._FTRL = params.useFTRL;
        this._globalBias = params.globalBias;
        this._linearCoeff = params.linearCoeff;
        this._numFeatures = params.numFeatures;
        this._numFields = params.numFields;
        return cl;
    }

    /**
     * Initializes the parent UDTF and allocates the reusable field list.
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        final StructObjectInspector oi = super.initialize(argOIs);
        this._fieldList = new IntArrayList();
        return oi;
    }

    /**
     * Output schema: {@code model_id string, model string}.
     */
    @Override
    protected StructObjectInspector getOutputOI(@Nonnull FMHyperParameters params) {
        final ArrayList<String> fieldNames = new ArrayList<String>();
        // Must be typed as ObjectInspector: getStandardStructObjectInspector takes
        // List<ObjectInspector>, and generics are invariant.
        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        fieldNames.add("model_id");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        fieldNames.add("model");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Creates the FFM model and keeps a typed reference to it in {@code _ffmModel}.
     */
    @Override
    protected FFMStringFeatureMapModel initModel(@Nonnull FMHyperParameters params) throws UDFArgumentException {
        final FMHyperParameters.FFMHyperParameters ffmParams = (FMHyperParameters.FFMHyperParameters) params;
        final FFMStringFeatureMapModel model = new FFMStringFeatureMapModel(ffmParams);
        this._ffmModel = model;
        return model;
    }

    /**
     * Parses the raw feature array into FFM features, using the configured number of
     * features and fields.
     */
    @Override
    protected Feature[] parseFeatures(@Nonnull Object arg) throws HiveException {
        return Feature.parseFFMFeatures(arg, this._xOI, this._probes, this._numFeatures, this._numFields);
    }

    /**
     * Runs one training step on a single instance.
     *
     * @param x features of the instance
     * @param y target value
     * @param adaptiveRegularization ignored; adaptive regularization is not supported here
     * @throws HiveException wrapping any failure raised by the training step, annotated
     *         with the call index {@code _t}
     */
    @Override
    public void train(@Nonnull Feature[] x, double y, boolean adaptiveRegularization) throws HiveException {
        this._ffmModel.check(x);
        try {
            trainTheta(x, y);
        } catch (Exception ex) {
            throw new HiveException("Exception caused in the " + this._t + "-th call of train()", ex);
        }
    }

    /**
     * SGD update of the model parameters for one instance.
     *
     * <p>It records the loss, then returns early if the loss gradient is (close to) zero.
     * Otherwise it updates {@code w0} when enabled. For every non-zero feature it updates
     * {@code w_i} and, when {@link #updateWi} allows it, the latent factors for every
     * (field, factor) pair.
     */
    @Override
    protected void trainTheta(@Nonnull Feature[] x, double y) throws HiveException {
        final float eta_t = this._etaEstimator.eta(this._t);
        final double p = this._ffmModel.predict(x);
        final double lossGrad = this._ffmModel.dloss(p, y);
        final double loss = this._lossFunction.loss(p, y);
        this._cvState.incrLoss(loss);
        if (MathUtils.closeToZero(lossGrad)) {
            return; // nothing to learn from this instance
        }
        if (this._globalBias) {
            this._ffmModel.updateW0(lossGrad, eta_t);
        }

        final IntArrayList fieldList = getFieldList(x);
        try {
            // sumVfX is indexed as [position in x][position in fieldList][factor].
            final DoubleArray3D sumVfX = this._ffmModel.sumVfX(x, fieldList, this._sumVfX);
            final int numFieldSlots = fieldList.size();
            final int k = this._factors;
            for (int i = 0; i < x.length; ++i) {
                final Feature x_i = x[i];
                if (x_i.value == 0.0) {
                    continue;
                }
                if (!updateWi(lossGrad, x_i, eta_t)) {
                    continue; // V of this feature is not updated
                }
                for (int fieldIndex = 0; fieldIndex < numFieldSlots; ++fieldIndex) {
                    final int yField = fieldList.get(fieldIndex);
                    for (int f = 0; f < k; ++f) {
                        final double sumViX = sumVfX.get(i, fieldIndex, f);
                        this._ffmModel.updateV(lossGrad, x_i, yField, f, sumViX, this._t);
                    }
                }
            }
            // Keep the buffer for reuse by the next instance.
            sumVfX.clear();
            this._sumVfX = sumVfX;
        } finally {
            // Always reset the shared scratch list so a failed step cannot leave stale
            // fields behind for the next call.
            fieldList.clear();
        }
    }

    /**
     * Updates the linear coefficient of {@code xi} when the linear term is enabled.
     *
     * @return whether the latent factors of {@code xi} should be updated afterwards. The
     *         value is always {@code true} when the linear term is disabled or FTRL is not
     *         used; otherwise it is the result of {@code updateWiFTRL}.
     */
    private boolean updateWi(double lossGrad, @Nonnull Feature xi, float eta) {
        if (!this._linearCoeff) {
            return true;
        }
        if (this._FTRL) {
            return this._ffmModel.updateWiFTRL(lossGrad, xi, eta);
        }
        this._ffmModel.updateWi(lossGrad, xi, eta);
        return true;
    }

    /**
     * Appends the field of every feature in {@code x}, in order, to the shared scratch
     * list and returns it. There is one entry per feature, and repeated fields are not
     * collapsed. The caller must {@code clear()} the list after use.
     */
    @Nonnull
    private IntArrayList getFieldList(@Nonnull Feature[] x) {
        for (Feature e : x) {
            final int field = e.getField();
            this._fieldList.add(field);
        }
        return this._fieldList;
    }

    /**
     * Deserializes a feature as an {@link IntFeature}.
     */
    @Override
    protected IntFeature instantiateFeature(@Nonnull ByteBuffer input) {
        return new IntFeature(input);
    }

    /**
     * Closes the parent UDTF, then drops the reference to the FFM model.
     */
    @Override
    public void close() throws HiveException {
        super.close();
        this._ffmModel = null;
    }

    /**
     * Releases the training buffers, converts the model to an {@link FFMPredictionModel},
     * serializes it, and forwards one {@code (taskId, serializedModel)} row.
     *
     * @throws HiveException if serialization fails
     */
    @Override
    protected void forwardModel() throws HiveException {
        // Drop training-only state before materializing the prediction model.
        this._model = null;
        this._fieldList = null;
        this._sumVfX = null;

        final Text modelId = new Text();
        final String taskId = HadoopUtils.getUniqueTaskIdString();
        modelId.set(taskId);

        FFMPredictionModel predModel = this._ffmModel.toPredictionModel();
        this._ffmModel = null; // help GC
        if (LOG.isInfoEnabled()) {
            LOG.info("Serializing a model '" + modelId + "'... Configured # features: " + this._numFeatures + ", Configured # fields: " + this._numFields + ", Actual # features: " + predModel.getActualNumFeatures() + ", Estimated uncompressed bytes: " + NumberUtils.prettySize(predModel.approxBytesConsumed()));
        }

        byte[] serialized;
        try {
            serialized = predModel.serialize();
            predModel = null; // help GC
        } catch (IOException e) {
            throw new HiveException("Failed to serialize a model", e);
        }
        if (LOG.isInfoEnabled()) {
            LOG.info("Forwarding a serialized/compressed model '" + modelId + "' of size: " + NumberUtils.prettySize(serialized.length));
        }

        final Text3 modelObj = new Text3(serialized);
        // NOTE(review): only releases memory if Text3 copies the array; confirm in Text3.
        serialized = null;
        final Object[] forwardObjs = new Object[]{modelId, modelObj};
        forward(forwardObjs);
    }
}

