/*
 * Decompiled with CFR 0.152.
 */
package hivemall.smile.classification;

import hivemall.UDTFWithOptions;
import hivemall.smile.ModelType;
import hivemall.smile.classification.DecisionTree;
import hivemall.smile.data.Attribute;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.SmileTaskExecutor;
import hivemall.utils.codec.Base91;
import hivemall.utils.codec.DeflateCodec;
import hivemall.utils.collections.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.RandomUtils;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.MapredContextAccessor;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;
import smile.math.Math;
import smile.math.Random;

/**
 * Hive UDTF that trains a random forest classifier.
 * <p>
 * Input rows ({@code array<double> features, int label}) are buffered in {@link #process(Object[])}.
 * Training happens in {@link #close()}: {@code num_trees} decision trees are built as
 * {@link TrainingTask}s run by a {@link SmileTaskExecutor}, and one output row is forwarded per
 * tree. Out-of-bag (OOB) statistics are computed only for the row forwarded by the task that
 * finishes last; every other row reports {@code 0} for {@code oob_errors} and {@code oob_tests}.
 */
@Description(name="train_randomforest_classifier", value="_FUNC_(double[] features, int label [, string options]) - Returns a relation consists of <string model_id, int model_type, string pred_model, array<double> var_importance, int oob_errors, int oob_tests>")
public final class RandomForestClassifierUDTF
extends UDTFWithOptions {
    private static final Log logger = LogFactory.getLog(RandomForestClassifierUDTF.class);

    // Input object inspectors (released in close()).
    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector featureElemOI;
    private PrimitiveObjectInspector labelOI;

    // Rows buffered until close().
    private List<double[]> featuresList;
    private IntArrayList labels;

    // Hyper-parameters resolved in processOptions().
    private int _numTrees;
    private float _numVars;
    private int _maxDepth;
    private int _maxLeafNodes;
    private int _minSamplesSplit;
    private int _minSamplesLeaf;
    private long _seed;
    private Attribute[] _attributes;
    private ModelType _outputType;
    private DecisionTree.SplitRule _splitRule;

    @Nullable
    private Reporter _progressReporter;
    @Nullable
    private Counters.Counter _treeBuildTaskCounter;

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("trees", "num_trees", true, "The number of trees for each task [default: 50]");
        opts.addOption("vars", "num_variables", true, "The number of random selected features [default: ceil(sqrt(x[0].length))]. int(num_variables * x[0].length) is considered if num_variable is (0,1]");
        opts.addOption("depth", "max_depth", true, "The maximum number of the tree depth [default: Integer.MAX_VALUE]");
        opts.addOption("leafs", "max_leaf_nodes", true, "The maximum number of leaf nodes [default: Integer.MAX_VALUE]");
        opts.addOption("splits", "min_split", true, "A node that has greater than or equals to `min_split` examples will split [default: 2]");
        opts.addOption("min_samples_leaf", true, "The minimum number of samples in a leaf node [default: 1]");
        opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
        opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types (Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
        opts.addOption("output", "output_type", true, "The output type (serialization/ser or opscode/vm or javascript/js) [default: serialization]");
        opts.addOption("rule", "split_rule", true, "Split algorithm [default: GINI, ENTROPY]");
        opts.addOption("disable_compression", false, "Whether to disable compression of the output script [default: false]");
        return opts;
    }

    /**
     * Parses the optional third (constant string) argument and stores the resolved
     * hyper-parameters in this instance. Defaults are used for every option that is absent.
     *
     * @return the parsed command line, or {@code null} when no options argument was given
     * @throws IllegalArgumentException if {@code num_trees} is less than 1
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        int trees = 50;
        int maxDepth = Integer.MAX_VALUE;
        int numLeafs = Integer.MAX_VALUE;
        int minSplits = 2;
        int minSamplesLeaf = 1;
        float numVars = -1.0f; // negative => resolved from the data in SmileExtUtils.computeNumInputVars
        Attribute[] attrs = null;
        long seed = -1L; // -1 => random seed per tree
        String output = "serialization";
        DecisionTree.SplitRule splitRule = DecisionTree.SplitRule.GINI;
        boolean compress = true;

        CommandLine cl = null;
        if (argOIs.length >= 3) {
            String rawArgs = HiveUtils.getConstString(argOIs[2]);
            cl = parseOptions(rawArgs);

            trees = Primitives.parseInt(cl.getOptionValue("num_trees"), trees);
            if (trees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + trees);
            }
            numVars = Primitives.parseFloat(cl.getOptionValue("num_variables"), numVars);
            maxDepth = Primitives.parseInt(cl.getOptionValue("max_depth"), maxDepth);
            numLeafs = Primitives.parseInt(cl.getOptionValue("max_leaf_nodes"), numLeafs);
            minSplits = Primitives.parseInt(cl.getOptionValue("min_split"), minSplits);
            minSamplesLeaf = Primitives.parseInt(cl.getOptionValue("min_samples_leaf"), minSamplesLeaf);
            seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
            attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
            output = cl.getOptionValue("output", output);
            splitRule = SmileExtUtils.resolveSplitRule(cl.getOptionValue("split_rule", "GINI"));
            if (cl.hasOption("disable_compression")) {
                compress = false;
            }
        }

        this._numTrees = trees;
        this._numVars = numVars;
        this._maxDepth = maxDepth;
        this._maxLeafNodes = numLeafs;
        this._minSamplesSplit = minSplits;
        this._minSamplesLeaf = minSamplesLeaf;
        this._seed = seed;
        this._attributes = attrs;
        this._outputType = ModelType.resolve(output, compress);
        this._splitRule = splitRule;
        return cl;
    }

    /**
     * Validates the argument types, resolves options and declares the output row schema
     * {@code <model_id string, model_type int, pred_model string, var_importance array<double>,
     * oob_errors int, oob_tests int>}.
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 2 && argOIs.length != 3) {
            throw new UDFArgumentException(getClass().getSimpleName() + " takes 2 or 3 arguments: double[] features, int label [, const string options]: " + argOIs.length);
        }
        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
        ObjectInspector elemOI = listOI.getListElementObjectInspector();
        this.featureListOI = listOI;
        this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
        this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);

        processOptions(argOIs);

        this.featuresList = new ArrayList<double[]>(1024);
        this.labels = new IntArrayList(1024);

        ArrayList<String> fieldNames = new ArrayList<String>(6);
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);
        fieldNames.add("model_id");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        fieldNames.add("model_type");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("pred_model");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        fieldNames.add("var_importance");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        fieldNames.add("oob_errors");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("oob_tests");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Buffers one training example. Training itself is deferred to {@link #close()}.
     *
     * @throws HiveException if the features or the label is null
     */
    @Override
    public void process(Object[] args) throws HiveException {
        if (args[0] == null) {
            throw new HiveException("array<double> features was null");
        }
        if (args[1] == null) {
            // Fail with a clear message instead of an NPE from the object inspector.
            throw new HiveException("int label was null");
        }
        double[] features = HiveUtils.asDoubleArray(args[0], featureListOI, featureElemOI);
        int label = PrimitiveObjectInspectorUtils.getInt(args[1], labelOI);
        featuresList.add(features);
        labels.add(label);
    }

    /**
     * Trains the forest on all buffered rows (if any) and forwards one row per tree, then releases
     * the buffers and object inspectors.
     */
    @Override
    public void close() throws HiveException {
        this._progressReporter = getReporter();
        this._treeBuildTaskCounter = (_progressReporter == null) ? null
                : _progressReporter.getCounter("hivemall.smile.RandomForestClassifier$Counter", "finishedTreeBuildTasks");
        reportProgress(_progressReporter);

        final int numExamples = featuresList.size();
        if (numExamples > 0) {
            double[][] x = featuresList.toArray(new double[numExamples][]);
            // Drop the buffer references before training so they can be garbage collected.
            this.featuresList = null;
            int[] y = labels.toArray();
            this.labels = null;
            train(x, y);
        }

        this.featuresList = null;
        this.labels = null;
        this.featureListOI = null;
        this.featureElemOI = null;
        this.labelOI = null;
        this._attributes = null;
    }

    /** Rejects hyper-parameter values that cannot produce a valid tree. */
    private void checkOptions() throws HiveException {
        if (_minSamplesSplit <= 0) {
            throw new HiveException("Invalid minSamplesSplit: " + _minSamplesSplit);
        }
        if (_maxDepth < 1) {
            throw new HiveException("Invalid maxDepth: " + _maxDepth);
        }
    }

    /**
     * Builds {@code _numTrees} decision trees on bootstrap samples of (x, y) and forwards each
     * model as soon as it is built.
     *
     * @param x training examples; shuffled in place (together with y) before training
     * @param y class labels, one per row of x
     * @throws HiveException if x and y differ in length, options are invalid, or a task fails
     */
    private void train(@Nonnull double[][] x, @Nonnull int[] y) throws HiveException {
        if (x.length != y.length) {
            throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        checkOptions();

        SmileExtUtils.shuffle(x, y, _seed);
        int[] labels = SmileExtUtils.classLables(y);
        Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
        int numInputVars = SmileExtUtils.computeNumInputVars(_numVars, x);

        if (logger.isInfoEnabled()) {
            logger.info("numTrees: " + _numTrees + ", numVars: " + numInputVars + ", maxDepth: " + _maxDepth
                    + ", minSamplesSplit: " + _minSamplesSplit + ", minSamplesLeaf: " + _minSamplesLeaf
                    + ", maxLeafs: " + _maxLeafNodes + ", splitRule: " + _splitRule + ", seed: " + _seed);
        }

        final int numExamples = x.length;
        // prediction[i][c] counts the OOB votes for class c on example i, shared by all tasks.
        int[][] prediction = new int[numExamples][labels.length];
        int[][] order = SmileExtUtils.sort(attributes, x);
        AtomicInteger remainingTasks = new AtomicInteger(_numTrees);

        List<TrainingTask> tasks = new ArrayList<TrainingTask>(_numTrees);
        for (int i = 0; i < _numTrees; i++) {
            // Derive a distinct, reproducible seed per tree unless seeding is random (-1).
            long s = (_seed == -1L) ? -1L : _seed + i;
            tasks.add(new TrainingTask(this, i, attributes, x, y, numInputVars, order, prediction, s, remainingTasks));
        }

        MapredContext mapredContext = MapredContextAccessor.get();
        SmileTaskExecutor executor = new SmileTaskExecutor(mapredContext);
        try {
            executor.run(tasks);
        } catch (Exception ex) {
            throw new HiveException(ex);
        } finally {
            executor.shotdown();
        }
    }

    /**
     * Forwards one model row. Synchronized because training tasks may call it concurrently.
     * <p>
     * When {@code lastTask} is true, the OOB error and test counts are computed from the votes
     * accumulated in {@code prediction}. An example counts as tested only if it received at least
     * one OOB vote.
     */
    synchronized void forward(int taskId, @Nonnull Text model, @Nonnull double[] importance, int[] y, int[][] prediction, boolean lastTask) throws HiveException {
        int oobErrors = 0;
        int oobTests = 0;
        if (lastTask) {
            for (int i = 0; i < y.length; i++) {
                int pred = Math.whichMax(prediction[i]);
                if (prediction[i][pred] <= 0) {
                    continue; // never out-of-bag for any tree
                }
                oobTests++;
                if (pred != y[i]) {
                    oobErrors++;
                }
            }
        }

        String modelId = RandomUtils.getUUID();
        Object[] forwardObjs = new Object[] {new Text(modelId), new IntWritable(_outputType.getId()), model, WritableUtils.toWritableList(importance), new IntWritable(oobErrors), new IntWritable(oobTests)};
        forward(forwardObjs);

        reportProgress(_progressReporter);
        incrCounter(_treeBuildTaskCounter, 1L);

        logger.info("Forwarded " + taskId + "-th DecisionTree out of " + _numTrees);
    }

    /** Builds a single decision tree on a bootstrap sample and forwards it. */
    private static final class TrainingTask
    implements Callable<Integer> {
        private final Attribute[] _attributes;
        private final double[][] _x;
        private final int[] _y;
        private final int[][] _order;
        private final int _numVars;
        private final int[][] _prediction;
        private final RandomForestClassifierUDTF _udtf;
        private final int _taskId;
        private final long _seed;
        private final AtomicInteger _remainingTasks;

        TrainingTask(RandomForestClassifierUDTF udtf, int taskId, Attribute[] attributes, double[][] x, int[] y, int numVars, int[][] order, int[][] prediction, long seed, AtomicInteger remainingTasks) {
            this._udtf = udtf;
            this._taskId = taskId;
            this._attributes = attributes;
            this._x = x;
            this._y = y;
            this._order = order;
            this._numVars = numVars;
            this._prediction = prediction;
            this._seed = seed;
            this._remainingTasks = remainingTasks;
        }

        /**
         * Samples N rows with replacement, trains a tree on them, records OOB votes for rows that
         * were never sampled, and forwards the serialized model.
         *
         * @return the number of tasks still remaining after this one
         */
        @Override
        public Integer call() throws HiveException {
            long s = (_seed == -1L) ? SmileExtUtils.generateSeed() : new Random(_seed).nextLong();
            final Random rnd1 = new Random(s);
            final Random rnd2 = new Random(rnd1.nextLong());
            final int N = _x.length;

            // Bootstrap sample: bags[i] is the row drawn at step i; sampled marks drawn rows.
            int[] bags = new int[N];
            BitSet sampled = new BitSet(N);
            for (int i = 0; i < N; i++) {
                int index = rnd1.nextInt(N);
                bags[i] = index;
                sampled.set(index);
            }

            DecisionTree tree = new DecisionTree(_attributes, _x, _y, _numVars, _udtf._maxDepth, _udtf._maxLeafNodes, _udtf._minSamplesSplit, _udtf._minSamplesLeaf, bags, _order, _udtf._splitRule, rnd2);

            // OOB voting on rows not in the bootstrap sample; each row's vote array is shared
            // across tasks, so it is used as its own lock.
            for (int i = sampled.nextClearBit(0); i < N; i = sampled.nextClearBit(i + 1)) {
                final int p = tree.predict(_x[i]);
                final int[] votes = _prediction[i];
                synchronized (votes) {
                    votes[p]++;
                }
            }

            Text model = getModel(tree, _udtf._outputType);
            double[] importance = tree.importance();
            int remain = _remainingTasks.decrementAndGet();
            boolean lastTask = (remain == 0);
            _udtf.forward(_taskId + 1, model, importance, _y, _prediction, lastTask);
            return remain;
        }

        /**
         * Serializes {@code tree} according to {@code outputType}; compressed variants are
         * deflated and then Base91-encoded.
         */
        private static Text getModel(@Nonnull DecisionTree tree, @Nonnull ModelType outputType) throws HiveException {
            switch (outputType) {
                case serialization:
                case serialization_compressed: {
                    byte[] b = tree.predictSerCodegen(outputType.isCompressed());
                    return new Text(Base91.encode(b));
                }
                case opscode:
                case opscode_compressed: {
                    String s = tree.predictOpCodegen("; ");
                    return outputType.isCompressed() ? compressAndEncode(s) : new Text(s);
                }
                case javascript:
                case javascript_compressed: {
                    String s = tree.predictJsCodegen();
                    return outputType.isCompressed() ? compressAndEncode(s) : new Text(s);
                }
                default:
                    throw new HiveException("Unexpected output type: " + outputType + ". Use javascript for the output instead");
            }
        }

        /** Deflates the bytes of {@code script} and wraps the Base91 encoding in a {@link Text}. */
        private static Text compressAndEncode(@Nonnull String script) throws HiveException {
            // NOTE(review): getBytes() uses the platform charset; kept for compatibility with the
            // existing decoder side — confirm before switching to an explicit charset.
            byte[] b = script.getBytes();
            DeflateCodec codec = new DeflateCodec(true, false);
            try {
                b = codec.compress(b);
            } catch (IOException e) {
                throw new HiveException("Failed to compress a model", e);
            } finally {
                IOUtils.closeQuietly((Closeable) codec);
            }
            return new Text(Base91.encode(b));
        }
    }
}

