/*
 * Decompiled with CFR 0.152.
 */
package hivemall.smile.tools;

import hivemall.smile.ModelType;
import hivemall.smile.classification.DecisionTree;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.vm.StackMachine;
import hivemall.smile.vm.VMRuntimeException;
import hivemall.utils.codec.Base91;
import hivemall.utils.codec.DeflateCodec;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.IOUtils;
import java.io.Closeable;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.script.Bindings;
import javax.script.Compilable;
import javax.script.CompiledScript;
import javax.script.ScriptEngine;
import javax.script.ScriptEngineManager;
import javax.script.ScriptException;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobConf;

/**
 * Hive generic UDF {@code tree_predict}: evaluates a single tree model, passed in as a string
 * column, against a feature vector.
 *
 * <p>Arguments are {@code (modelId, modelType, script, features [, classification])}. The model
 * type selects one of three {@link Evaluator} implementations (Java serialization, stack-machine
 * op-codes, or JavaScript). When the optional constant {@code classification} flag is {@code true}
 * the UDF returns an {@code int}, otherwise a {@code double}.
 *
 * <p>Each evaluator keeps the most recently materialised model keyed by {@code modelId}, so
 * consecutive rows that carry the same model id do not pay the decode/compile cost again.
 */
@Description(name="tree_predict", value="_FUNC_(string modelId, int modelType, string script, array<double> features [, const boolean classification]) - Returns a prediction result of a random forest")
@UDFType(deterministic=true, stateful=false)
public final class TreePredictUDF extends GenericUDF {
    /** {@code true} when the 5th (constant) argument requested classification output. */
    private boolean classification;
    private PrimitiveObjectInspector modelTypeOI;
    private StringObjectInspector stringOI;
    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector featureElemOI;
    /** Lazily created on the first evaluated row; released in {@link #close()}. */
    @Nullable
    private transient Evaluator evaluator;
    /** Model type {@link #evaluator} was created for; used to detect a change of model type. */
    @Nullable
    private transient ModelType evaluatorModelType;
    /** Cleared by {@link #configure(MapredContext)} when the job sets {@code td.jar.version}. */
    private boolean support_javascript_eval = true;

    /**
     * Disables JavaScript evaluation when the job configuration defines {@code td.jar.version}.
     *
     * @param context the map-reduce context; may be {@code null}
     */
    public void configure(MapredContext context) {
        super.configure(context);
        if (context == null) {
            return;
        }
        JobConf conf = context.getJobConf();
        // Only the presence of the property matters, not its value.
        if (conf != null && conf.get("td.jar.version") != null) {
            this.support_javascript_eval = false;
        }
    }

    /**
     * Validates the argument inspectors and picks the return type.
     *
     * <p>The inspector of the first argument (model id) is not examined; the value is later read
     * through {@code toString()}.
     *
     * @param argOIs inspectors for the 4 or 5 call arguments
     * @return a writable-int inspector for classification, otherwise a writable-double inspector
     * @throws UDFArgumentException if the argument count is neither 4 nor 5, or a helper rejects
     *         one of the inspectors
     */
    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 4 && argOIs.length != 5) {
            throw new UDFArgumentException("_FUNC_ takes 4 or 5 arguments");
        }
        this.modelTypeOI = HiveUtils.asIntegerOI(argOIs[1]);
        this.stringOI = HiveUtils.asStringOI(argOIs[2]);
        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[3]);
        this.featureListOI = listOI;
        this.featureElemOI = HiveUtils.asDoubleCompatibleOI(listOI.getListElementObjectInspector());

        // Regression is the default when the optional flag is omitted.
        this.classification = argOIs.length == 5 && HiveUtils.getConstBoolean(argOIs[4]);
        return this.classification
                ? PrimitiveObjectInspectorFactory.writableIntObjectInspector
                : PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
    }

    /**
     * Evaluates the model in {@code arguments[2]} against the features in {@code arguments[3]}.
     *
     * @param arguments deferred call arguments in the order declared by the UDF description
     * @return the prediction, or {@code null} when the script argument is {@code null} or the
     *         evaluator yields no result
     * @throws HiveException if the model id or the feature array is {@code null}, or the
     *         evaluator fails
     */
    public Writable evaluate(@Nonnull GenericUDF.DeferredObject[] arguments) throws HiveException {
        Object arg0 = arguments[0].get();
        if (arg0 == null) {
            throw new HiveException("ModelId was null");
        }
        String modelId = arg0.toString();

        Object arg1 = arguments[1].get();
        int modelTypeId = PrimitiveObjectInspectorUtils.getInt(arg1, this.modelTypeOI);
        ModelType modelType = ModelType.resolve(modelTypeId);

        Object arg2 = arguments[2].get();
        if (arg2 == null) {
            return null;
        }
        Text script = this.stringOI.getPrimitiveWritableObject(arg2);

        Object arg3 = arguments[3].get();
        if (arg3 == null) {
            throw new HiveException("array<double> features was null");
        }
        double[] features = HiveUtils.asDoubleArray(arg3, this.featureListOI, this.featureElemOI);

        Evaluator current = this.evaluatorFor(modelType);
        return current.evaluate(modelId, modelType.isCompressed(), script, features, this.classification);
    }

    /**
     * Returns the cached evaluator, replacing it when the row's model type differs from the one
     * the cached evaluator was created for.
     */
    @Nonnull
    private Evaluator evaluatorFor(@Nonnull ModelType modelType) throws UDFArgumentException {
        if (this.evaluator != null && this.evaluatorModelType == modelType) {
            return this.evaluator;
        }
        // Create the replacement first so a failure leaves the current evaluator untouched.
        Evaluator created = TreePredictUDF.getEvaluator(modelType, this.support_javascript_eval);
        if (this.evaluator != null) {
            IOUtils.closeQuietly((Closeable)this.evaluator);
        }
        this.evaluator = created;
        this.evaluatorModelType = modelType;
        return created;
    }

    /**
     * Creates the evaluator matching {@code type}.
     *
     * @param type model encoding; compressed and uncompressed variants share an evaluator class
     * @param supportJavascriptEval whether JavaScript models may be evaluated
     * @throws UDFArgumentException for an unknown type, or a JavaScript type while JavaScript
     *         evaluation is disabled
     */
    @Nonnull
    private static Evaluator getEvaluator(@Nonnull ModelType type, boolean supportJavascriptEval) throws UDFArgumentException {
        switch (type) {
            case serialization:
            case serialization_compressed:
                return new JavaSerializationEvaluator();
            case opscode:
            case opscode_compressed:
                return new StackmachineEvaluator();
            case javascript:
            case javascript_compressed:
                if (!supportJavascriptEval) {
                    throw new UDFArgumentException("Javascript evaluation is not allowed in Treasure Data env");
                }
                return new JavascriptEvaluator();
            default:
                throw new UDFArgumentException("Unexpected model type was detected: " + type);
        }
    }

    /**
     * Base91-decodes and inflates a compressed script into text.
     *
     * <p>The bytes are decoded as UTF-8, the same encoding {@code Text.toString()} applies on the
     * uncompressed path. NOTE(review): assumes the model producer emitted ASCII/UTF-8 script text
     * before compressing - confirm against the training UDTFs.
     *
     * @throws HiveException if decompression fails
     */
    @Nonnull
    private static String decompressScript(@Nonnull Text script, @Nonnull DeflateCodec codec) throws HiveException {
        // Text exposes its backing array; only the first getLength() bytes are valid.
        byte[] decoded = Base91.decode(script.getBytes(), 0, script.getLength());
        try {
            return new String(codec.decompress(decoded), StandardCharsets.UTF_8);
        }
        catch (IOException e) {
            throw new HiveException("decompression failed", e);
        }
    }

    /** Drops the object inspectors and closes the cached evaluator, if any. */
    public void close() throws IOException {
        this.modelTypeOI = null;
        this.stringOI = null;
        this.featureElemOI = null;
        this.featureListOI = null;
        IOUtils.closeQuietly((Closeable)this.evaluator);
        this.evaluator = null;
        this.evaluatorModelType = null;
    }

    public String getDisplayString(String[] children) {
        return "tree_predict(" + Arrays.toString(children) + ")";
    }

    /**
     * Evaluates models shipped as JavaScript source through a compilable {@code javax.script}
     * engine. The feature array is bound to the script variable {@code x}.
     */
    static final class JavascriptEvaluator implements Evaluator {
        private final ScriptEngine scriptEngine;
        private final Compilable compilableEngine;
        @Nullable
        private String prevModelId = null;
        @Nullable
        private CompiledScript prevCompiled = null;
        @Nullable
        private DeflateCodec codec = null;

        /**
         * @throws UDFArgumentException if the JVM has no engine registered for the {@code js}
         *         extension, or that engine cannot pre-compile scripts
         */
        JavascriptEvaluator() throws UDFArgumentException {
            ScriptEngine engine = new ScriptEngineManager().getEngineByExtension("js");
            if (engine == null) {
                // getEngineByExtension returns null when no engine is registered; without this
                // guard the message construction below would fail with a NullPointerException.
                throw new UDFArgumentException("No ScriptEngine is available for the 'js' extension in this JVM");
            }
            if (!(engine instanceof Compilable)) {
                throw new UDFArgumentException("ScriptEngine was not compilable: " + engine.getFactory().getEngineName() + " version " + engine.getFactory().getEngineVersion());
            }
            this.scriptEngine = engine;
            this.compilableEngine = (Compilable)engine;
        }

        /**
         * Compiles the script on a model-id change (decoding it first when {@code compressed}),
         * then evaluates it with {@code features} bound to {@code x}.
         *
         * @return {@code null} if the script yields {@code null}; otherwise the numeric result as
         *         {@code IntWritable} (classification) or {@code DoubleWritable}
         * @throws HiveException on decompression, compilation or evaluation failure, or when the
         *         script result is not a {@link Number}
         */
        @Override
        public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script, double[] features, boolean classification) throws HiveException {
            CompiledScript compiled = this.prevCompiled;
            if (compiled == null || !modelId.equals(this.prevModelId)) {
                compiled = this.compile(compressed, script);
                // Remember the id together with the compiled script so the next row can reuse it.
                this.prevModelId = modelId;
                this.prevCompiled = compiled;
            }

            Bindings bindings = this.scriptEngine.createBindings();
            Object result;
            try {
                bindings.put("x", features);
                result = compiled.eval(bindings);
            }
            catch (Throwable e) {
                // Deliberately broad: any failure inside the script engine is reported per row.
                throw new HiveException("failed to evaluate: \n" + script, e);
            }
            finally {
                bindings.clear();
            }

            if (result == null) {
                return null;
            }
            if (!(result instanceof Number)) {
                throw new HiveException("Got an unexpected non-number result: " + result);
            }
            Number number = (Number)result;
            if (classification) {
                return new IntWritable(number.intValue());
            }
            return new DoubleWritable(number.doubleValue());
        }

        /** Decodes (when compressed) and compiles the script text. */
        @Nonnull
        private CompiledScript compile(boolean compressed, @Nonnull Text script) throws HiveException {
            String scriptStr;
            if (compressed) {
                if (this.codec == null) {
                    this.codec = new DeflateCodec(false, true);
                }
                scriptStr = TreePredictUDF.decompressScript(script, this.codec);
            } else {
                scriptStr = script.toString();
            }
            try {
                return this.compilableEngine.compile(scriptStr);
            }
            catch (ScriptException e) {
                throw new HiveException("failed to compile: \n" + script, e);
            }
        }

        @Override
        public void close() throws IOException {
            IOUtils.closeQuietly((Closeable)this.codec);
        }
    }

    /** Evaluates models shipped as op-code scripts using {@link StackMachine}. */
    static final class StackmachineEvaluator implements Evaluator {
        @Nullable
        private String prevModelId = null;
        @Nullable
        private StackMachine prevVM = null;
        @Nullable
        private DeflateCodec codec = null;

        StackmachineEvaluator() {
        }

        /**
         * Compiles the op-code script on a model-id change (decoding it first when
         * {@code compressed}), then runs the stack machine on {@code features}.
         *
         * @return {@code null} if the machine produced no result; otherwise the result as
         *         {@code IntWritable} (classification) or {@code DoubleWritable}
         * @throws HiveException on decompression, compilation or evaluation failure
         */
        @Override
        public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script, double[] features, boolean classification) throws HiveException {
            StackMachine vm = this.prevVM;
            if (vm == null || !modelId.equals(this.prevModelId)) {
                vm = this.compile(compressed, script);
                this.prevModelId = modelId;
                this.prevVM = vm;
            }

            try {
                vm.eval(features);
            }
            catch (Throwable e) {
                // Deliberately broad: covers VMRuntimeException and any unchecked failure.
                throw new HiveException("failed to eval StackMachine", e);
            }

            Double result = vm.getResult();
            if (result == null) {
                return null;
            }
            if (classification) {
                return new IntWritable(result.intValue());
            }
            return new DoubleWritable(result.doubleValue());
        }

        /** Decodes (when compressed) the script and compiles it into a fresh stack machine. */
        @Nonnull
        private StackMachine compile(boolean compressed, @Nonnull Text script) throws HiveException {
            String scriptStr;
            if (compressed) {
                if (this.codec == null) {
                    this.codec = new DeflateCodec(false, true);
                }
                scriptStr = TreePredictUDF.decompressScript(script, this.codec);
            } else {
                scriptStr = script.toString();
            }
            StackMachine vm = new StackMachine();
            try {
                vm.compile(scriptStr);
            }
            catch (VMRuntimeException e) {
                throw new HiveException("failed to compile StackMachine", e);
            }
            return vm;
        }

        @Override
        public void close() throws IOException {
            IOUtils.closeQuietly((Closeable)this.codec);
        }
    }

    /**
     * Evaluates models shipped as Base91-encoded serialized tree nodes. Exactly one of
     * {@code cNode}/{@code rNode} is populated at a time, and it always belongs to
     * {@code prevModelId}.
     */
    static final class JavaSerializationEvaluator implements Evaluator {
        @Nullable
        private String prevModelId = null;
        @Nullable
        private DecisionTree.Node cNode = null;
        @Nullable
        private RegressionTree.Node rNode = null;

        JavaSerializationEvaluator() {
        }

        /**
         * Dispatches to the classification or regression tree depending on {@code classification}.
         *
         * @throws HiveException if the model cannot be deserialized
         */
        @Override
        public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script, double[] features, boolean classification) throws HiveException {
            if (classification) {
                return this.evaluateClassification(modelId, compressed, script, features);
            }
            return this.evaluateRegression(modelId, compressed, script, features);
        }

        private IntWritable evaluateClassification(@Nonnull String modelId, boolean compressed, @Nonnull Text script, double[] features) throws HiveException {
            if (this.cNode == null || !modelId.equals(this.prevModelId)) {
                byte[] decoded = Base91.decode(script.getBytes(), 0, script.getLength());
                DecisionTree.Node node = DecisionTree.deserializeNode(decoded, decoded.length, compressed);
                if (node == null) {
                    throw new HiveException("failed to deserialize a classification tree for model: " + modelId);
                }
                // Commit the cache only after deserialization succeeded, so a failure cannot
                // leave the id pointing at a stale or missing node.
                this.cNode = node;
                this.rNode = null;
                this.prevModelId = modelId;
            }
            return new IntWritable(this.cNode.predict(features));
        }

        private DoubleWritable evaluateRegression(@Nonnull String modelId, boolean compressed, @Nonnull Text script, double[] features) throws HiveException {
            if (this.rNode == null || !modelId.equals(this.prevModelId)) {
                byte[] decoded = Base91.decode(script.getBytes(), 0, script.getLength());
                RegressionTree.Node node = RegressionTree.deserializeNode(decoded, decoded.length, compressed);
                if (node == null) {
                    throw new HiveException("failed to deserialize a regression tree for model: " + modelId);
                }
                // Same commit-after-success rule as the classification path.
                this.rNode = node;
                this.cNode = null;
                this.prevModelId = modelId;
            }
            return new DoubleWritable(this.rNode.predict(features));
        }

        @Override
        public void close() throws IOException {
        }
    }

    /** Strategy that turns one encoded model plus a feature vector into a prediction. */
    public static interface Evaluator extends Closeable {
        /**
         * @param modelId identifier of the model; implementations use it as their cache key
         * @param compressed whether {@code script} holds a compressed payload
         * @param script the encoded model
         * @param features feature vector to predict for
         * @param classification {@code true} for an integer class label, {@code false} for a
         *        double regression value
         * @return the prediction, or {@code null} if the model produced none
         * @throws HiveException if the model cannot be decoded or evaluated
         */
        @Nullable
        public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script, @Nonnull double[] features, boolean classification) throws HiveException;
    }
}

