/*
 * Decompiled with CFR 0.152.
 */
package hivemall.fm;

import hivemall.fm.FFMPredictionModel;
import hivemall.fm.Feature;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.NumberUtils;
import java.io.IOException;
import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.Text;

@Description(name="ffm_predict", value="_FUNC_(string modelId, string model, array<string> features) returns a prediction result in double from a Field-aware Factorization Machine")
@UDFType(deterministic=true, stateful=false)
public final class FFMPredictUDF
extends GenericUDF {
    private StringObjectInspector _modelIdOI;
    private StringObjectInspector _modelOI;
    private ListObjectInspector _featureListOI;
    private DoubleWritable _result;
    @Nullable
    private String _cachedModeId;
    @Nullable
    private FFMPredictionModel _cachedModel;
    @Nullable
    private Feature[] _probes;

    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 3) {
            throw new UDFArgumentException("_FUNC_ takes 3 arguments");
        }
        this._modelIdOI = HiveUtils.asStringOI(argOIs[0]);
        this._modelOI = HiveUtils.asStringOI(argOIs[1]);
        this._featureListOI = HiveUtils.asListOI(argOIs[2]);
        this._result = new DoubleWritable();
        return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
    }

    public Object evaluate(GenericUDF.DeferredObject[] args) throws HiveException {
        Feature[] x;
        FFMPredictionModel model;
        String modelId = this._modelIdOI.getPrimitiveJavaObject(args[0].get());
        if (modelId == null) {
            throw new HiveException("modelId is not set");
        }
        if (modelId.equals(this._cachedModeId)) {
            model = this._cachedModel;
        } else {
            Text serModel = this._modelOI.getPrimitiveWritableObject(args[1].get());
            if (serModel == null) {
                throw new HiveException("Model is null for model ID: " + modelId);
            }
            byte[] b = serModel.getBytes();
            int length = serModel.getLength();
            try {
                model = FFMPredictionModel.deserialize(b, length);
                b = null;
            }
            catch (ClassNotFoundException e) {
                throw new HiveException((Throwable)e);
            }
            catch (IOException e) {
                throw new HiveException((Throwable)e);
            }
            this._cachedModeId = modelId;
            this._cachedModel = model;
        }
        int numFeatures = model.getNumFeatures();
        int numFields = model.getNumFields();
        Object arg2 = args[2].get();
        if (arg2 instanceof LazyBinaryArray) {
            arg2 = ((LazyBinaryArray)arg2).getList();
        }
        if ((x = Feature.parseFFMFeatures(arg2, this._featureListOI, this._probes, numFeatures, numFields)) == null || x.length == 0) {
            return null;
        }
        this._probes = x;
        double predicted = FFMPredictUDF.predict(x, model);
        this._result.set(predicted);
        return this._result;
    }

    private static double predict(@Nonnull Feature[] x, @Nonnull FFMPredictionModel model) throws HiveException {
        double ret = model.getW0();
        for (Feature e : x) {
            double xi = e.getValue();
            float wi = model.getW(e);
            double wx = (double)wi * xi;
            ret += wx;
        }
        int factors = model.getNumFactors();
        float[] vij = new float[factors];
        float[] vji = new float[factors];
        for (int i = 0; i < x.length; ++i) {
            Feature ei = x[i];
            double xi = ei.getValue();
            short iField = ei.getField();
            for (int j = i + 1; j < x.length; ++j) {
                Feature ej = x[j];
                double xj = ej.getValue();
                short jField = ej.getField();
                if (!model.getV(ei, jField, vij) || !model.getV(ej, iField, vij)) continue;
                for (int f = 0; f < factors; ++f) {
                    float vijf = vij[f];
                    float vjif = vji[f];
                    ret += (double)(vijf * vjif) * xi * xj;
                }
            }
        }
        if (!NumberUtils.isFinite(ret)) {
            throw new HiveException("Detected " + ret + " in ffm_predict");
        }
        return ret;
    }

    public void close() throws IOException {
        super.close();
        this._cachedModel = null;
        this._probes = null;
    }

    public String getDisplayString(String[] args) {
        return "ffm_predict(" + Arrays.toString(args) + ")";
    }
}

