/*
 * Decompiled with CFR 0.152.
 */
package hivemall.smile.classification;

import hivemall.smile.data.Attribute;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.utils.collections.IntArrayList;
import hivemall.utils.lang.ObjectUtils;
import hivemall.utils.lang.StringUtils;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import smile.classification.Classifier;
import smile.math.Math;
import smile.math.Random;

/**
 * A CART-style classification tree over dense {@code double[]} feature vectors.
 *
 * <p>Each internal node splits on a single attribute. Nominal attributes are split by equality with
 * one category value; numeric attributes are split by a {@code <=} threshold placed halfway between
 * two consecutive values of the attribute's pre-sorted sample order. Split quality is the impurity
 * decrease measured by the configured {@link SplitRule}. When {@code numVars} is smaller than the
 * number of attributes, only the first {@code numVars} entries of a random permutation of the
 * attributes are examined at each node.
 *
 * <p>When {@code maxLeafs} is {@link Integer#MAX_VALUE} the tree is grown recursively (depth-first).
 * Otherwise candidate nodes are kept in a priority queue ordered by split score, and at most
 * {@code maxLeafs - 1} split attempts are made, highest score first.
 *
 * <p>A trained tree can be exported as JavaScript-like source ({@link #predictJsCodegen()}), as an
 * op-code script ({@link #predictOpCodegen(String)}), or as a serialized {@link Node} graph
 * ({@link #predictSerCodegen(boolean)}).
 */
public final class DecisionTree
implements Classifier<double[]> {
    /** Attribute types, one per column of the training matrix. */
    private final Attribute[] _attributes;
    /** Whether any attribute is numeric; if so, per-sample bag multiplicities are built at each node. */
    private final boolean _hasNumericType;
    /** Per-feature sum of the split scores (impurity decreases) of every accepted split on that feature. */
    private final double[] _importance;
    /** Root of the trained tree. */
    private final Node _root;
    /** Nodes at this depth or deeper are never split (the root has depth 1). */
    private final int _maxDepth;
    /** Impurity measure used to score candidate splits. */
    private final SplitRule _rule;
    /** Number of classes: the largest label in {@code y} plus one. */
    private final int _k;
    /** Number of (randomly permuted) attributes evaluated at each node. */
    private final int _numVars;
    /**
     * Node-size threshold. A node is split only if it holds more than this many samples, and candidate
     * splits leaving fewer than this many samples on either side are skipped.
     */
    private final int _minSplit;
    /** Minimum number of samples each child must receive for a chosen split to be kept. */
    private final int _minLeafSize;
    /**
     * For each attribute, the sample indices ordered by that attribute's value. This is produced by
     * {@code SmileExtUtils.sort} unless supplied by the caller; the numeric split scan assumes
     * ascending order (TODO confirm against SmileExtUtils).
     */
    private final int[][] _order;
    /** Source of randomness for attribute subsampling. */
    private final Random _rnd;

    /** Appends two spaces per indentation level to {@code builder}. */
    private static void indent(@Nonnull StringBuilder builder, int depth) {
        for (int level = 0; level < depth; ++level) {
            builder.append("  ");
        }
    }

    /**
     * Computes the impurity of a class histogram.
     *
     * @param count per-class sample counts; non-positive entries are ignored
     * @param n total number of samples represented by {@code count}
     * @param rule impurity measure to apply
     * @return Gini index ({@code 1 - sum p^2}), entropy in bits ({@code -sum p log2 p}), or
     *         classification error ({@code 1 - max p}), depending on {@code rule}
     */
    private static double impurity(@Nonnull int[] count, int n, @Nonnull SplitRule rule) {
        switch (rule) {
            case GINI: {
                double gini = 1.0;
                for (int c : count) {
                    if (c > 0) {
                        double p = (double) c / n;
                        gini -= p * p;
                    }
                }
                return gini;
            }
            case ENTROPY: {
                double entropy = 0.0;
                for (int c : count) {
                    if (c > 0) {
                        double p = (double) c / n;
                        entropy -= p * Math.log2(p);
                    }
                }
                return entropy;
            }
            case CLASSIFICATION_ERROR: {
                double maxP = 0.0;
                for (int c : count) {
                    if (c > 0) {
                        maxP = Math.max(maxP, (double) c / n);
                    }
                }
                return Math.abs(1.0 - maxP);
            }
            default:
                return 0.0;
        }
    }

    /**
     * Trains a tree using all features, unlimited depth, Gini impurity and a fresh random source.
     *
     * @param attributes attribute types, or {@code null} to infer them from {@code x}
     * @param x training samples (rows), must be non-empty
     * @param y class labels in {@code [0, k)}, one per row of {@code x}
     * @param numLeafs maximum number of leaves ({@link Integer#MAX_VALUE} for no limit)
     * @throws IllegalArgumentException if the arguments are inconsistent or out of range
     */
    public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @Nonnull int[] y, int numLeafs) {
        this(attributes, x, y, numFeatures(x), Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, null);
    }

    /**
     * Trains a tree using all features, unlimited depth, Gini impurity and the given random source.
     *
     * @param attributes attribute types, or {@code null} to infer them from {@code x}
     * @param x training samples (rows), must be non-empty
     * @param y class labels in {@code [0, k)}, one per row of {@code x}
     * @param numLeafs maximum number of leaves ({@link Integer#MAX_VALUE} for no limit)
     * @param rand random source, or {@code null} to create a new one
     * @throws IllegalArgumentException if the arguments are inconsistent or out of range
     */
    public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @Nonnull int[] y, int numLeafs, @Nullable Random rand) {
        this(attributes, x, y, numFeatures(x), Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, rand);
    }

    /**
     * Trains a tree with full control over the growth parameters.
     *
     * @param attributes attribute types, or {@code null} to infer them from {@code x}
     * @param x training samples (rows), must be non-empty
     * @param y non-negative class labels, one per row of {@code x}; at least label 1 must occur
     * @param numVars number of attributes examined at each node, in {@code [1, x[0].length]}
     * @param maxDepth maximum tree depth, at least 2
     * @param maxLeafs maximum number of leaves, at least 2 ({@link Integer#MAX_VALUE} for no limit)
     * @param minSplits node-size threshold for splitting, at least 2
     * @param minLeafSize minimum number of samples per child, at least 1
     * @param bags row indices of the training sample (may contain repeats), or {@code null} to use
     *        every row once
     * @param order per-attribute sorted sample indices, or {@code null} to compute them
     * @param rule impurity measure
     * @param rand random source, or {@code null} to create a new one
     * @throws IllegalArgumentException if the arguments are inconsistent or out of range
     */
    public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @Nonnull int[] y, int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize, @Nullable int[] bags, @Nullable int[][] order, @Nonnull SplitRule rule, @Nullable Random rand) {
        checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize);
        this._k = Math.max(y) + 1;
        if (this._k < 2) {
            throw new IllegalArgumentException("Only one class or negative class labels.");
        }
        this._attributes = SmileExtUtils.attributeTypes(attributes, x);
        // Validate the resolved types: `attributes` itself may legitimately be null here.
        if (this._attributes.length != x[0].length) {
            throw new IllegalArgumentException("-attrs option is invalid: " + Arrays.toString(attributes));
        }
        this._hasNumericType = SmileExtUtils.containsNumericType(this._attributes);
        this._numVars = numVars;
        this._maxDepth = maxDepth;
        this._minSplit = minSplits;
        this._minLeafSize = minLeafSize;
        this._rule = rule;
        this._order = (order == null) ? SmileExtUtils.sort(this._attributes, x) : order;
        this._importance = new double[this._attributes.length];
        this._rnd = (rand == null) ? new Random() : rand;

        // Class histogram of the (bagged) training sample; its majority class labels the root.
        final int[] count = new int[this._k];
        if (bags == null) {
            final int n = y.length;
            bags = new int[n];
            for (int i = 0; i < n; ++i) {
                bags[i] = i;
                count[y[i]]++;
            }
        } else {
            // Iterate the bag itself so its length need not equal y.length.
            for (int index : bags) {
                count[y[index]]++;
            }
        }
        this._root = new Node(Math.whichMax(count));

        final TrainNode trainRoot = new TrainNode(this._root, x, y, bags, 1);
        if (maxLeafs == Integer.MAX_VALUE) {
            // Unlimited leaves: split recursively.
            if (trainRoot.findBestSplit()) {
                trainRoot.split(null);
            }
        } else {
            // Limited leaves: always expand the queued node with the highest split score next.
            final PriorityQueue<TrainNode> nextSplits = new PriorityQueue<TrainNode>();
            if (trainRoot.findBestSplit()) {
                nextSplits.add(trainRoot);
            }
            for (int leaves = 1; leaves < maxLeafs; ++leaves) {
                final TrainNode parent = nextSplits.poll();
                if (parent == null) {
                    break;
                }
                parent.split(nextSplits);
            }
        }
    }

    /**
     * Returns the number of columns of {@code x}.
     *
     * @throws IllegalArgumentException if {@code x} has no rows
     */
    private static int numFeatures(@Nonnull double[][] x) {
        if (x.length == 0) {
            throw new IllegalArgumentException("X must not be empty");
        }
        return x[0].length;
    }

    /** Validates constructor arguments, throwing {@link IllegalArgumentException} on the first violation. */
    private static void checkArgument(@Nonnull double[][] x, @Nonnull int[] y, int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (x.length == 0) {
            throw new IllegalArgumentException("X must not be empty");
        }
        for (int label : y) {
            // Labels index the class histograms directly, so they must be non-negative.
            if (label < 0) {
                throw new IllegalArgumentException("Negative class label: " + label);
            }
        }
        if (numVars <= 0 || numVars > x[0].length) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + numVars);
        }
        if (maxDepth < 2) {
            throw new IllegalArgumentException("maxDepth should be greater than 1: " + maxDepth);
        }
        if (maxLeafs < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + maxLeafs);
        }
        if (minSplits < 2) {
            throw new IllegalArgumentException("Invalid minimum number of samples required to split an internal node: " + minSplits);
        }
        if (minLeafSize < 1) {
            throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + minLeafSize);
        }
    }

    /**
     * Returns the per-feature importance: the sum of split scores of all accepted splits on each
     * feature. The internal array is returned, not a copy.
     */
    public double[] importance() {
        return this._importance;
    }

    /** Predicts the class of {@code x} by walking the tree from the root. */
    @Override
    public int predict(double[] x) {
        return this._root.predict(x);
    }

    /** Not supported: this tree does not estimate posterior probabilities. */
    @Override
    public int predict(double[] x, double[] posteriori) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /** Renders the tree as nested JavaScript-like {@code if/else} statements. */
    public String predictJsCodegen() {
        final StringBuilder buf = new StringBuilder(1024);
        this._root.jsCodegen(buf, 0);
        return buf.toString();
    }

    /**
     * Renders the tree as an op-code script terminated by {@code "call end"}.
     *
     * @param sep separator placed between consecutive operations
     */
    public String predictOpCodegen(String sep) {
        final ArrayList<String> opslist = new ArrayList<String>();
        this._root.opCodegen(opslist, 0);
        opslist.add("call end");
        return StringUtils.concat(opslist, sep);
    }

    /**
     * Serializes the root node (and hence the whole tree) via {@link Node#writeExternal}.
     *
     * @param compress whether to compress the serialized bytes
     * @throws HiveException wrapping any failure during serialization
     */
    @Nonnull
    public byte[] predictSerCodegen(boolean compress) throws HiveException {
        try {
            return compress ? ObjectUtils.toCompressedBytes(this._root) : ObjectUtils.toBytes(this._root);
        } catch (IOException ioe) {
            throw new HiveException("IOException cause while serializing DecisionTree object", ioe);
        } catch (Exception e) {
            throw new HiveException("Exception cause while serializing DecisionTree object", e);
        }
    }

    /**
     * Rebuilds a tree produced by {@link #predictSerCodegen(boolean)}.
     *
     * @param serializedObj serialized bytes
     * @param length number of bytes of {@code serializedObj} to read (presumably from offset 0)
     * @param compressed whether the bytes were written with compression
     * @throws HiveException wrapping any failure during deserialization
     */
    public static Node deserializeNode(byte[] serializedObj, int length, boolean compressed) throws HiveException {
        final Node root = new Node();
        try {
            if (compressed) {
                ObjectUtils.readCompressedObject(serializedObj, 0, length, root);
            } else {
                ObjectUtils.readObject(serializedObj, length, root);
            }
        } catch (IOException ioe) {
            throw new HiveException("IOException cause while deserializing DecisionTree object", ioe);
        } catch (Exception e) {
            throw new HiveException("Exception cause while deserializing DecisionTree object", e);
        }
        return root;
    }

    @Override
    public String toString() {
        return this._root == null ? "" : this.predictJsCodegen();
    }

    /**
     * Training-time wrapper around a {@link Node}. It holds the sample indices that reached the node
     * and its depth.
     *
     * <p>Ordered by descending split score, so a {@link PriorityQueue} polls the most promising node
     * first.
     */
    private final class TrainNode
    implements Comparable<TrainNode> {
        final Node node;
        final double[][] x;
        final int[] y;
        /** Row indices (possibly repeated) that reached this node; cleared once the node is split. */
        int[] bags;
        final int depth;

        public TrainNode(Node node, double[][] x, int[] y, int[] bags, int depth) {
            this.node = node;
            this.x = x;
            this.y = y;
            this.bags = bags;
            this.depth = depth;
        }

        @Override
        public int compareTo(TrainNode a) {
            // Reversed on purpose: a higher splitScore sorts earlier.
            return Double.compare(a.node.splitScore, this.node.splitScore);
        }

        /**
         * Searches the (subsampled) attributes for the split with the largest positive impurity
         * decrease and records it on {@link #node}.
         *
         * @return {@code true} if a split was found
         */
        public boolean findBestSplit() {
            if (depth >= _maxDepth) {
                return false;
            }
            final int numSamples = bags.length;
            if (numSamples <= _minSplit) {
                return false;
            }
            final int[] count = new int[_k];
            if (sampleCount(count)) {
                return false; // all samples share one label: nothing to gain
            }
            final double impurity = DecisionTree.impurity(count, numSamples, _rule);

            final int p = _attributes.length;
            final int[] variableIndex = new int[p];
            for (int i = 0; i < p; ++i) {
                variableIndex[i] = i;
            }
            if (_numVars < p) {
                SmileExtUtils.shuffle(variableIndex, _rnd);
            }

            // Per-row multiplicity in the bag; only the numeric split scan needs it.
            final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.length) : null;
            final int[] falseCount = new int[_k];
            for (int j = 0; j < _numVars; ++j) {
                final Node split = findBestSplit(numSamples, count, falseCount, impurity, variableIndex[j], samples);
                if (split.splitScore > node.splitScore) {
                    node.splitFeature = split.splitFeature;
                    node.splitFeatureType = split.splitFeatureType;
                    node.splitValue = split.splitValue;
                    node.splitScore = split.splitScore;
                    node.trueChildOutput = split.trueChildOutput;
                    node.falseChildOutput = split.falseChildOutput;
                }
            }
            return node.splitFeature != -1;
        }

        /**
         * Fills {@code count} with the class histogram of this node's bag.
         *
         * @return {@code true} if every sample in the bag has the same label
         */
        private boolean sampleCount(@Nonnull int[] count) {
            int label = -1;
            boolean pure = true;
            for (int index : bags) {
                final int y_i = y[index];
                count[y_i]++;
                if (label == -1) {
                    label = y_i;
                } else if (y_i != label) {
                    pure = false;
                }
            }
            return pure;
        }

        /**
         * Finds the best split on attribute {@code j}.
         *
         * @param n number of samples at this node
         * @param count class histogram at this node
         * @param falseCount scratch buffer of length {@code _k}
         * @param impurity impurity of this node
         * @param j attribute index
         * @param samples per-row bag multiplicities (required for numeric attributes)
         * @return a detached node describing the best split; its {@code splitFeature} is -1 if no
         *         split with positive gain exists
         */
        private Node findBestSplit(int n, int[] count, int[] falseCount, double impurity, int j, @Nullable int[] samples) {
            final Node splitNode = new Node();
            final Attribute.AttributeType type = _attributes[j].type;
            if (type == Attribute.AttributeType.NOMINAL) {
                // One candidate per category value l: "x_j == l" vs. everything else.
                final int m = _attributes[j].getSize();
                final int[][] trueCount = new int[m][_k];
                for (int index : bags) {
                    trueCount[(int) x[index][j]][y[index]]++;
                }
                for (int l = 0; l < m; ++l) {
                    final int tc = Math.sum(trueCount[l]);
                    final int fc = n - tc;
                    if (tc < _minSplit || fc < _minSplit) {
                        continue;
                    }
                    for (int q = 0; q < _k; ++q) {
                        falseCount[q] = count[q] - trueCount[l][q];
                    }
                    final double gain = impurity - (double) tc / n * DecisionTree.impurity(trueCount[l], tc, _rule) - (double) fc / n * DecisionTree.impurity(falseCount, fc, _rule);
                    if (gain > splitNode.splitScore) {
                        splitNode.splitFeature = j;
                        splitNode.splitFeatureType = Attribute.AttributeType.NOMINAL;
                        splitNode.splitValue = l;
                        splitNode.splitScore = gain;
                        splitNode.trueChildOutput = Math.whichMax(trueCount[l]);
                        splitNode.falseChildOutput = Math.whichMax(falseCount);
                    }
                }
            } else if (type == Attribute.AttributeType.NUMERIC) {
                // Sweep samples in attribute order. A threshold is evaluated only where both the
                // value and the label change relative to the previous sample.
                final int[] trueCount = new int[_k];
                double prevx = Double.NaN;
                int prevy = -1;
                assert (samples != null);
                for (int i : _order[j]) {
                    final int sample = samples[i];
                    if (sample <= 0) {
                        continue; // row i is not in this node's bag
                    }
                    final double x_ij = x[i][j];
                    final int y_i = y[i];
                    if (!Double.isNaN(prevx) && x_ij != prevx && y_i != prevy) {
                        final int tc = Math.sum(trueCount);
                        final int fc = n - tc;
                        if (tc >= _minSplit && fc >= _minSplit) {
                            for (int l = 0; l < _k; ++l) {
                                falseCount[l] = count[l] - trueCount[l];
                            }
                            final double gain = impurity - (double) tc / n * DecisionTree.impurity(trueCount, tc, _rule) - (double) fc / n * DecisionTree.impurity(falseCount, fc, _rule);
                            if (gain > splitNode.splitScore) {
                                splitNode.splitFeature = j;
                                splitNode.splitFeatureType = Attribute.AttributeType.NUMERIC;
                                splitNode.splitValue = (x_ij + prevx) / 2.0;
                                splitNode.splitScore = gain;
                                splitNode.trueChildOutput = Math.whichMax(trueCount);
                                splitNode.falseChildOutput = Math.whichMax(falseCount);
                            }
                        }
                    }
                    prevx = x_ij;
                    prevy = y_i;
                    trueCount[y_i] += sample;
                }
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + type);
            }
            return splitNode;
        }

        /**
         * Applies the split recorded on {@link #node}, creating both children. Each child is split
         * further, either immediately (when {@code nextSplits} is {@code null}) or by queueing it.
         *
         * @param nextSplits queue for best-first growth, or {@code null} for recursive growth
         * @return {@code false} if either child would receive fewer than {@code _minLeafSize} samples.
         *         In that case the node is reverted to a leaf.
         * @throws IllegalStateException if no split has been recorded on the node
         */
        public boolean split(@Nullable PriorityQueue<TrainNode> nextSplits) {
            if (node.splitFeature < 0) {
                throw new IllegalStateException("Split a node with invalid feature.");
            }
            final int childBagSize = (int) (bags.length * 0.4);
            IntArrayList trueBags = new IntArrayList(childBagSize);
            IntArrayList falseBags = new IntArrayList(childBagSize);
            final int tc = splitSamples(trueBags, falseBags);
            final int fc = bags.length - tc;
            bags = null; // the children now own the sample indices

            if (tc < _minLeafSize || fc < _minLeafSize) {
                node.splitFeature = -1;
                node.splitFeatureType = null;
                node.splitValue = Double.NaN;
                node.splitScore = 0.0;
                return false;
            }

            node.trueChild = new Node(node.trueChildOutput);
            final TrainNode trueChild = new TrainNode(node.trueChild, x, y, trueBags.toArray(), depth + 1);
            trueBags = null; // drop the builder before descending
            if (tc >= _minSplit && trueChild.findBestSplit()) {
                if (nextSplits != null) {
                    nextSplits.add(trueChild);
                } else {
                    trueChild.split(null);
                }
            }

            node.falseChild = new Node(node.falseChildOutput);
            final TrainNode falseChild = new TrainNode(node.falseChild, x, y, falseBags.toArray(), depth + 1);
            falseBags = null;
            if (fc >= _minSplit && falseChild.findBestSplit()) {
                if (nextSplits != null) {
                    nextSplits.add(falseChild);
                } else {
                    falseChild.split(null);
                }
            }

            _importance[node.splitFeature] += node.splitScore;
            return true;
        }

        /**
         * Partitions this node's bag according to its recorded split.
         *
         * @return the number of samples sent to the true branch
         */
        private int splitSamples(@Nonnull IntArrayList trueBags, @Nonnull IntArrayList falseBags) {
            final int splitFeature = node.splitFeature;
            final double splitValue = node.splitValue;
            final boolean nominal;
            if (node.splitFeatureType == Attribute.AttributeType.NOMINAL) {
                nominal = true;
            } else if (node.splitFeatureType == Attribute.AttributeType.NUMERIC) {
                nominal = false;
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + node.splitFeatureType);
            }
            int tc = 0;
            for (int index : bags) {
                final double v = x[index][splitFeature];
                final boolean goesTrue = nominal ? (v == splitValue) : (v <= splitValue);
                if (goesTrue) {
                    trueBags.add(index);
                    ++tc;
                } else {
                    falseBags.add(index);
                }
            }
            return tc;
        }
    }

    /**
     * A tree node. A node with no children is a leaf predicting {@link #output}; otherwise it routes
     * on {@code x[splitFeature]}. For a nominal attribute the test is {@code == splitValue}; for a
     * numeric attribute it is {@code <= splitValue}. A passing test goes to {@link #trueChild}.
     *
     * <p>Only the fields needed for prediction are written by {@link #writeExternal}.
     */
    public static final class Node
    implements Externalizable {
        /** Predicted class at this node (meaningful for leaves). */
        int output = -1;
        /** Attribute index tested at this node, or -1 if unsplit. */
        int splitFeature = -1;
        Attribute.AttributeType splitFeatureType = null;
        /** Category value (nominal) or threshold (numeric) of the split. */
        double splitValue = Double.NaN;
        /** Impurity decrease of the split (training only; not serialized). */
        double splitScore = 0.0;
        Node trueChild = null;
        Node falseChild = null;
        /** Majority class of the true side (training only; not serialized). */
        int trueChildOutput = -1;
        /** Majority class of the false side (training only; not serialized). */
        int falseChildOutput = -1;

        /** Creates an empty node; required by {@link Externalizable}. */
        public Node() {
        }

        public Node(int output) {
            this.output = output;
        }

        /** Routes {@code x} down to a leaf and returns that leaf's {@link #output}. */
        public int predict(double[] x) {
            if (trueChild == null && falseChild == null) {
                return output;
            }
            if (splitFeatureType == Attribute.AttributeType.NOMINAL) {
                return (x[splitFeature] == splitValue) ? trueChild.predict(x) : falseChild.predict(x);
            }
            if (splitFeatureType == Attribute.AttributeType.NUMERIC) {
                return (x[splitFeature] <= splitValue) ? trueChild.predict(x) : falseChild.predict(x);
            }
            throw new IllegalStateException("Unsupported attribute type: " + splitFeatureType);
        }

        /** Appends this subtree as nested {@code if/else} source text, indented by {@code depth} levels. */
        public void jsCodegen(@Nonnull StringBuilder builder, int depth) {
            if (trueChild == null && falseChild == null) {
                indent(builder, depth);
                builder.append(output).append(";\n");
                return;
            }
            final String op;
            final String elseLine;
            if (splitFeatureType == Attribute.AttributeType.NOMINAL) {
                op = "] == ";
                elseLine = "} else {\n";
            } else if (splitFeatureType == Attribute.AttributeType.NUMERIC) {
                op = "] <= ";
                elseLine = "} else  {\n";
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + splitFeatureType);
            }
            indent(builder, depth);
            builder.append("if(x[").append(splitFeature).append(op).append(splitValue).append(") {\n");
            trueChild.jsCodegen(builder, depth + 1);
            indent(builder, depth);
            builder.append(elseLine);
            falseChild.jsCodegen(builder, depth + 1);
            indent(builder, depth);
            builder.append("}\n");
        }

        /**
         * Appends this subtree's operations to {@code scripts}.
         *
         * @param scripts output list; presumably {@code scripts.size() == depth} on entry, since
         *        jump targets are computed from {@code depth}
         * @param depth index in {@code scripts} of this subtree's first operation
         * @return number of operations appended for this subtree
         */
        public int opCodegen(List<String> scripts, int depth) {
            if (trueChild == null && falseChild == null) {
                scripts.add("push " + output);
                scripts.add("goto last");
                return 2;
            }
            if (splitFeatureType == Attribute.AttributeType.NOMINAL) {
                return branchOpCodegen(scripts, depth, "ifeq");
            }
            if (splitFeatureType == Attribute.AttributeType.NUMERIC) {
                return branchOpCodegen(scripts, depth, "ifle");
            }
            throw new IllegalStateException("Unsupported attribute type: " + splitFeatureType);
        }

        /**
         * Emits the compare-and-jump header followed by the true and false subtrees.
         *
         * <p>The jump operand is the index where the false subtree starts. It is back-patched once
         * the true subtree's length is known.
         */
        private int branchOpCodegen(List<String> scripts, int depth, String jumpOp) {
            scripts.add("push x[" + splitFeature + "]");
            scripts.add("push " + splitValue);
            scripts.add(jumpOp + " "); // placeholder, patched below
            final int trueStart = depth + 3;
            final int trueLength = trueChild.opCodegen(scripts, trueStart);
            final int falseStart = trueStart + trueLength;
            scripts.set(trueStart - 1, jumpOp + " " + falseStart);
            final int falseLength = falseChild.opCodegen(scripts, falseStart);
            return 3 + trueLength + falseLength;
        }

        /**
         * Writes output, splitFeature, type id (-1 for none) and splitValue. It then writes a presence
         * flag and the recursive encoding for the true child, followed by the same for the false child.
         */
        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            out.writeInt(output);
            out.writeInt(splitFeature);
            out.writeInt(splitFeatureType == null ? -1 : splitFeatureType.getTypeId());
            out.writeDouble(splitValue);
            if (trueChild == null) {
                out.writeBoolean(false);
            } else {
                out.writeBoolean(true);
                trueChild.writeExternal(out);
            }
            if (falseChild == null) {
                out.writeBoolean(false);
            } else {
                out.writeBoolean(true);
                falseChild.writeExternal(out);
            }
        }

        /** Reads the format produced by {@link #writeExternal(ObjectOutput)}. */
        @Override
        public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
            output = in.readInt();
            splitFeature = in.readInt();
            final int typeId = in.readInt();
            splitFeatureType = (typeId == -1) ? null : Attribute.AttributeType.resolve(typeId);
            splitValue = in.readDouble();
            if (in.readBoolean()) {
                trueChild = new Node();
                trueChild.readExternal(in);
            }
            if (in.readBoolean()) {
                falseChild = new Node();
                falseChild.readExternal(in);
            }
        }
    }

    /** Impurity measure used to score candidate splits. */
    public static enum SplitRule {
        /** Gini index: {@code 1 - sum p^2}. */
        GINI,
        /** Entropy in bits: {@code -sum p log2 p}. */
        ENTROPY,
        /** Misclassification rate: {@code 1 - max p}. */
        CLASSIFICATION_ERROR;

    }
}

