/*
 * Decompiled with CFR 0.152.
 */
package hivemall.smile.regression;

import hivemall.smile.data.Attribute;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.utils.collections.IntArrayList;
import hivemall.utils.lang.ObjectUtils;
import hivemall.utils.lang.StringUtils;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import smile.math.Math;
import smile.math.Random;
import smile.regression.Regression;

/**
 * Binary regression tree grown by greedy recursive partitioning.
 *
 * <p>Each internal node tests a single attribute:
 * <ul>
 * <li>nominal attributes are tested for equality with one category value;</li>
 * <li>numeric attributes are tested against a threshold ({@code x <= threshold}).</li>
 * </ul>
 * Samples passing the test go to the "true" child, all others to the "false" child.
 *
 * <p>Leaves predict the mean training response of the samples that reach them. If a
 * {@link NodeOutput} is supplied, leaf outputs are instead recomputed by that callback once the
 * tree has been grown.
 *
 * <p>If {@code maxLeafs == Integer.MAX_VALUE} the tree is grown depth-first. Otherwise it is
 * grown best-first: the pending leaf with the highest split score is split repeatedly, until the
 * leaf budget is spent or no pending leaf can be split.
 */
public final class RegressionTree implements Regression<double[]> {

    /** Attribute type of every column of the training matrix. */
    private final Attribute[] _attributes;
    /** Whether any attribute is numeric; per-node sample counts are only built in that case. */
    private final boolean _hasNumericType;
    /** Split scores accumulated per attribute over all accepted splits. */
    private final double[] _importance;
    /** Root of the fitted tree. */
    private final Node _root;
    /** Nodes at this depth or deeper (the root has depth 1) are never split. */
    private final int _maxDepth;
    /**
     * A node is considered for splitting only if it holds more than this many samples, and each
     * side of a candidate split must receive at least this many samples.
     */
    private final int _minSplit;
    /** An accepted split is undone if either child would receive fewer samples than this. */
    private final int _minLeafSize;
    /** Number of attributes examined per node (a random subset when below the total count). */
    private final int _numVars;
    /**
     * For every attribute, row indices of the training matrix. The numeric split scan assumes
     * they are ordered by ascending value of that attribute.
     */
    private final int[][] _order;
    /** Source of randomness for per-node attribute sub-sampling. */
    private final Random _rnd;
    /** Optional callback recomputing leaf outputs after growth; {@code null} keeps the means. */
    @Nullable
    private final NodeOutput _nodeOutput;

    /** Appends {@code depth} levels of two-space indentation to {@code builder}. */
    private static void indent(StringBuilder builder, int depth) {
        for (int level = 0; level < depth; ++level) {
            builder.append("  ");
        }
    }

    /**
     * Grows a tree on all rows of {@code x}, examining every attribute at every node, with
     * unlimited depth, {@code minSplits = 5} and {@code minLeafSize = 1}.
     *
     * @param attributes attribute types, or {@code null} to let
     *        {@code SmileExtUtils.attributeTypes} derive them from {@code x}
     * @param x training rows
     * @param y response for each row of {@code x}
     * @param maxLeafs maximum number of leaves; {@code Integer.MAX_VALUE} grows depth-first
     * @throws IllegalArgumentException if any argument is invalid (see the full constructor)
     */
    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
            @Nonnull double[] y, int maxLeafs) {
        this(attributes, x, y, numColumns(x), Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, null);
    }

    /**
     * Same as {@link #RegressionTree(Attribute[], double[][], double[], int)}, but with an
     * explicit random source.
     *
     * @param rand random source; {@code null} creates a fresh one
     */
    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
            @Nonnull double[] y, int maxLeafs, @Nullable Random rand) {
        this(attributes, x, y, numColumns(x), Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, rand);
    }

    /**
     * Grows a tree without a leaf-output callback; see the full constructor for the parameters.
     */
    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
            @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits,
            int minLeafSize, @Nullable int[][] order, @Nullable int[] bags,
            @Nullable Random rand) {
        this(attributes, x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize, order, bags,
            null, rand);
    }

    /**
     * Grows a regression tree.
     *
     * @param attributes attribute types, or {@code null} to let
     *        {@code SmileExtUtils.attributeTypes} derive them from {@code x}
     * @param x training rows; must be non-empty
     * @param y response for each row of {@code x}
     * @param numVars number of attributes examined at each node, in {@code [1, x[0].length]}
     * @param maxDepth maximum tree depth (root is depth 1), at least 2
     * @param maxLeafs maximum number of leaves, at least 2; {@code Integer.MAX_VALUE} grows
     *        depth-first, any other value grows best-first
     * @param minSplits minimum sample count for splitting (see {@link #_minSplit}), at least 2
     * @param minLeafSize minimum number of samples per leaf, at least 1
     * @param order per-attribute row ordering, or {@code null} to compute it with
     *        {@code SmileExtUtils.sort}
     * @param bags row indices to train on (repeats act as sample weights), or {@code null} to
     *        use every row once
     * @param output optional callback recomputing leaf outputs after the tree is grown
     * @param rand random source for attribute sub-sampling; {@code null} creates a fresh one
     * @throws IllegalArgumentException if an argument is out of range, or the attribute count
     *         does not match the number of columns of {@code x}
     */
    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
            @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits,
            int minLeafSize, @Nullable int[][] order, @Nullable int[] bags,
            @Nullable NodeOutput output, @Nullable Random rand) {
        checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize);

        this._attributes = SmileExtUtils.attributeTypes(attributes, x);
        if (_attributes.length != x[0].length) {
            throw new IllegalArgumentException(
                "-attrs option is invalid: " + Arrays.toString(attributes));
        }
        this._hasNumericType = SmileExtUtils.containsNumericType(_attributes);
        this._numVars = numVars;
        this._maxDepth = maxDepth;
        this._minSplit = minSplits;
        this._minLeafSize = minLeafSize;
        this._order = (order == null) ? SmileExtUtils.sort(_attributes, x) : order;
        this._importance = new double[_attributes.length];
        this._rnd = (rand == null) ? new Random() : rand;
        this._nodeOutput = output;

        // Without explicit bags, every row takes part exactly once.
        final int[] rootBags;
        if (bags == null) {
            rootBags = new int[y.length];
            for (int i = 0; i < rootBags.length; ++i) {
                rootBags[i] = i;
            }
        } else {
            rootBags = bags;
        }
        double sum = 0.0;
        for (int index : rootBags) {
            sum += y[index];
        }

        this._root = new Node(sum / rootBags.length);
        final TrainNode trainRoot = new TrainNode(_root, x, y, rootBags, 1);
        if (maxLeafs == Integer.MAX_VALUE) {
            // Depth-first: split() recurses into both children immediately.
            if (trainRoot.findBestSplit()) {
                trainRoot.split(null);
            }
        } else {
            // Best-first: always split the pending leaf with the highest split score.
            final PriorityQueue<TrainNode> nextSplits = new PriorityQueue<TrainNode>();
            if (trainRoot.findBestSplit()) {
                nextSplits.add(trainRoot);
            }
            int leaves = 1;
            while (leaves < maxLeafs) {
                final TrainNode node = nextSplits.poll();
                if (node == null) {
                    break;
                }
                // A split rejected by the minLeafSize check leaves the node a leaf and adds
                // no new leaf, so it must not consume the leaf budget.
                if (node.split(nextSplits)) {
                    ++leaves;
                }
            }
        }
        if (output != null) {
            trainRoot.calculateOutput(output);
        }
    }

    /**
     * Returns the column count of {@code x}, or 0 when {@code x} has no rows. An empty matrix is
     * then rejected with a descriptive error by {@link #checkArgument}, instead of failing with
     * an index error here.
     */
    private static int numColumns(@Nonnull double[][] x) {
        return (x.length == 0) ? 0 : x[0].length;
    }

    /** Validates the training data and hyper-parameters, throwing on the first violation. */
    private static void checkArgument(@Nonnull double[][] x, @Nonnull double[] y, int numVars,
            int maxDepth, int maxLeafs, int minSplits, int minLeafSize) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(
                String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (x.length == 0) {
            throw new IllegalArgumentException("Training data is empty");
        }
        if (numVars <= 0 || numVars > x[0].length) {
            throw new IllegalArgumentException(
                "Invalid number of variables to split on at a node of the tree: " + numVars);
        }
        if (maxDepth < 2) {
            throw new IllegalArgumentException("maxDepth should be greater than 1: " + maxDepth);
        }
        if (maxLeafs < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + maxLeafs);
        }
        if (minSplits < 2) {
            throw new IllegalArgumentException(
                "Invalid minimum number of samples required to split an internal node: "
                        + minSplits);
        }
        if (minLeafSize < 1) {
            throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + minLeafSize);
        }
    }

    /**
     * Returns the per-attribute sum of split scores over all accepted splits. The internal array
     * is returned as-is (not a copy).
     */
    public double[] importance() {
        return _importance;
    }

    /** Predicts the response for a single row of attribute values. */
    @Override
    public double predict(double[] x) {
        return _root.predict(x);
    }

    /** Renders the tree as nested JavaScript {@code if/else} statements over an array {@code x}. */
    public String predictJsCodegen() {
        final StringBuilder buf = new StringBuilder(1024);
        _root.jsCodegen(buf, 0);
        return buf.toString();
    }

    /**
     * Renders the tree as a flat list of stack-machine operations, terminated by
     * {@code "call end"}.
     *
     * @param sep separator placed between operations
     */
    public String predictOpCodegen(@Nonnull String sep) {
        final List<String> ops = new ArrayList<String>();
        _root.opCodegen(ops, 0);
        ops.add("call end");
        return StringUtils.concat(ops, sep);
    }

    /**
     * Serializes the tree (root node and all descendants) to bytes.
     *
     * @param compress whether to use {@code ObjectUtils.toCompressedBytes}
     * @throws HiveException wrapping any failure during serialization
     */
    @Nonnull
    public byte[] predictSerCodegen(boolean compress) throws HiveException {
        try {
            return compress ? ObjectUtils.toCompressedBytes(_root) : ObjectUtils.toBytes(_root);
        } catch (IOException ioe) {
            throw new HiveException("IOException cause while serializing DecisionTree object", ioe);
        } catch (Exception e) {
            throw new HiveException("Exception cause while serializing DecisionTree object", e);
        }
    }

    /**
     * Reconstructs a tree from the output of {@link #predictSerCodegen(boolean)}.
     *
     * @param serializedObj serialized bytes
     * @param length number of valid bytes in {@code serializedObj}
     * @param compressed whether the bytes were produced with {@code compress = true}
     * @return the deserialized root node
     * @throws HiveException wrapping any failure during deserialization
     */
    public static Node deserializeNode(byte[] serializedObj, int length, boolean compressed)
            throws HiveException {
        final Node root = new Node();
        try {
            if (compressed) {
                ObjectUtils.readCompressedObject(serializedObj, 0, length, root);
            } else {
                ObjectUtils.readObject(serializedObj, length, root);
            }
        } catch (IOException ioe) {
            throw new HiveException("IOException cause while deserializing DecisionTree object",
                ioe);
        } catch (Exception e) {
            throw new HiveException("Exception cause while deserializing DecisionTree object", e);
        }
        return root;
    }

    /**
     * Training-time wrapper around a {@link Node}. It holds the samples that reach the node and
     * links to the training wrappers of its children.
     */
    private final class TrainNode implements Comparable<TrainNode> {
        /** The tree node being trained. */
        final Node node;
        TrainNode trueChild;
        TrainNode falseChild;
        final double[][] x;
        final double[] y;
        /**
         * Row indices reaching this node; repeats count as extra weight. Set to {@code null}
         * once no longer needed (kept for leaves when a {@link NodeOutput} will run).
         */
        int[] bags;
        /** Depth of this node; the root has depth 1. */
        final int depth;

        public TrainNode(Node node, double[][] x, double[] y, int[] bags, int depth) {
            this.node = node;
            this.x = x;
            this.y = y;
            this.bags = bags;
            this.depth = depth;
        }

        /**
         * Orders by descending split score. {@link PriorityQueue} is a min-heap, so it yields
         * the most promising split first.
         */
        @Override
        public int compareTo(TrainNode a) {
            return Double.compare(a.node.splitScore, this.node.splitScore);
        }

        /** Recursively replaces every leaf's output with {@code output.calculate(samples)}. */
        public void calculateOutput(NodeOutput output) {
            if (node.trueChild == null && node.falseChild == null) {
                final int[] samples = SmileExtUtils.bagsToSamples(bags);
                node.output = output.calculate(samples);
                return;
            }
            if (trueChild != null) {
                trueChild.calculateOutput(output);
            }
            if (falseChild != null) {
                falseChild.calculateOutput(output);
            }
        }

        /**
         * Searches a (possibly random) subset of attributes for the best split of this node and
         * records it on {@link #node}.
         *
         * @return {@code true} if a split was found
         */
        public boolean findBestSplit() {
            if (depth >= _maxDepth) {
                return false;
            }
            final int numSamples = bags.length;
            if (numSamples <= _minSplit) {
                return false;
            }

            // node.output holds the mean response of this node's samples, so this is their sum.
            final double sum = node.output * numSamples;

            final int p = _attributes.length;
            final int[] variables = new int[p];
            for (int v = 0; v < p; ++v) {
                variables[v] = v;
            }
            if (_numVars < p) {
                SmileExtUtils.shuffle(variables, _rnd);
            }

            // Per-row sample counts, only needed by the numeric scan over _order.
            final int[] samples =
                    _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.length) : null;
            for (int v = 0; v < _numVars; ++v) {
                final Node candidate = findBestSplit(numSamples, sum, variables[v], samples);
                if (candidate.splitScore > node.splitScore) {
                    node.splitFeature = candidate.splitFeature;
                    node.splitFeatureType = candidate.splitFeatureType;
                    node.splitValue = candidate.splitValue;
                    node.splitScore = candidate.splitScore;
                    node.trueChildOutput = candidate.trueChildOutput;
                    node.falseChildOutput = candidate.falseChildOutput;
                }
            }
            return node.splitFeature != -1;
        }

        /** Returns the best split on attribute {@code j}; its splitFeature stays -1 if none. */
        private Node findBestSplit(int n, double sum, int j, @Nullable int[] samples) {
            final Attribute.AttributeType type = _attributes[j].type;
            if (type == Attribute.AttributeType.NOMINAL) {
                return findBestNominalSplit(n, sum, j);
            }
            if (type == Attribute.AttributeType.NUMERIC) {
                return findBestNumericSplit(n, sum, j, samples);
            }
            throw new IllegalStateException("Unsupported attribute type: " + type);
        }

        /**
         * Tries every category {@code k} of nominal attribute {@code j} as the split
         * {@code x[j] == k}.
         *
         * <p>NOTE(review): the score is {@code tc*mean_t^2 + fc*mean_f^2 - n*split.output^2}
         * with {@code split.output == 0}. It differs from true variance reduction by the
         * node-constant {@code n*mean^2}. That constant does not change which split wins within
         * a node, but it does enter importance values and best-first ordering.
         */
        private Node findBestNominalSplit(int n, double sum, int j) {
            final Node split = new Node(0.0);
            final int m = _attributes[j].getSize();
            final double[] trueSum = new double[m];
            final int[] trueCount = new int[m];
            for (int i : bags) {
                final int category = (int) x[i][j];
                trueSum[category] += y[i];
                trueCount[category]++;
            }

            for (int k = 0; k < m; ++k) {
                final double tc = trueCount[k];
                final double fc = n - tc;
                // Skip categories that would leave either side with too few samples.
                if (tc < _minSplit || fc < _minSplit) {
                    continue;
                }
                final double trueMean = trueSum[k] / tc;
                final double falseMean = (sum - trueSum[k]) / fc;
                final double gain = tc * trueMean * trueMean + fc * falseMean * falseMean
                        - n * split.output * split.output;
                if (gain > split.splitScore) {
                    split.splitFeature = j;
                    split.splitFeatureType = Attribute.AttributeType.NOMINAL;
                    split.splitValue = k;
                    split.splitScore = gain;
                    split.trueChildOutput = trueMean;
                    split.falseChildOutput = falseMean;
                }
            }
            return split;
        }

        /**
         * Scans numeric attribute {@code j} in the order given by {@code _order[j]}. Every
         * boundary between two distinct consecutive values is tried as the threshold
         * {@code x[j] <= midpoint}.
         *
         * @param samples per-row multiplicity of this node's samples (0 = row not in node)
         */
        private Node findBestNumericSplit(int n, double sum, int j, int[] samples) {
            final Node split = new Node(0.0);
            double trueSum = 0.0;
            int trueCount = 0;
            double prevx = Double.NaN;
            for (int i : _order[j]) {
                final int sample = samples[i];
                if (sample <= 0) {
                    continue; // row i does not reach this node
                }
                final double xij = x[i][j];
                // A threshold is only possible between two distinct consecutive values.
                if (!Double.isNaN(prevx) && xij != prevx) {
                    final double falseCount = n - trueCount;
                    if (trueCount >= _minSplit && falseCount >= _minSplit) {
                        final double trueMean = trueSum / trueCount;
                        final double falseMean = (sum - trueSum) / falseCount;
                        final double gain = trueCount * trueMean * trueMean
                                + falseCount * falseMean * falseMean
                                - n * split.output * split.output;
                        if (gain > split.splitScore) {
                            split.splitFeature = j;
                            split.splitFeatureType = Attribute.AttributeType.NUMERIC;
                            split.splitValue = (xij + prevx) / 2.0;
                            split.splitScore = gain;
                            split.trueChildOutput = trueMean;
                            split.falseChildOutput = falseMean;
                        }
                    }
                }
                prevx = xij;
                trueSum += sample * y[i];
                trueCount += sample;
            }
            return split;
        }

        /**
         * Applies the split recorded by {@link #findBestSplit()}, creating both children.
         *
         * <p>With {@code nextSplits == null} (depth-first), each splittable child is split
         * recursively right away, the true child's subtree first. Otherwise each splittable
         * child is queued for the best-first loop.
         *
         * @return {@code false} if the split was undone because a child would be smaller than
         *         the minimum leaf size
         * @throws IllegalStateException if no split was recorded
         */
        public boolean split(@Nullable PriorityQueue<TrainNode> nextSplits) {
            if (node.splitFeature < 0) {
                throw new IllegalStateException("Split a node with invalid feature.");
            }

            final int childBagSize = (int) (bags.length * 0.4);
            IntArrayList trueBags = new IntArrayList(childBagSize);
            IntArrayList falseBags = new IntArrayList(childBagSize);
            final int tc = splitSamples(trueBags, falseBags);
            final int fc = bags.length - tc;
            if (tc < _minLeafSize || fc < _minLeafSize) {
                // Revert to a leaf; keep the samples only if a NodeOutput will need them.
                node.splitFeature = -1;
                node.splitFeatureType = null;
                node.splitValue = Double.NaN;
                node.splitScore = 0.0;
                if (_nodeOutput == null) {
                    bags = null;
                }
                return false;
            }

            bags = null;

            node.trueChild = new Node(node.trueChildOutput);
            this.trueChild = new TrainNode(node.trueChild, x, y, trueBags.toArray(), depth + 1);
            trueBags = null; // release before a possibly deep recursion
            scheduleSplit(trueChild, tc, nextSplits);

            node.falseChild = new Node(node.falseChildOutput);
            this.falseChild =
                    new TrainNode(node.falseChild, x, y, falseBags.toArray(), depth + 1);
            falseBags = null;
            scheduleSplit(falseChild, fc, nextSplits);

            _importance[node.splitFeature] += node.splitScore;
            return true;
        }

        /**
         * If {@code child} can be split, either queues it (best-first) or splits it immediately
         * (depth-first, when {@code nextSplits} is {@code null}).
         */
        private void scheduleSplit(TrainNode child, int childSize,
                @Nullable PriorityQueue<TrainNode> nextSplits) {
            if (childSize >= _minSplit && child.findBestSplit()) {
                if (nextSplits != null) {
                    nextSplits.add(child);
                } else {
                    child.split(null);
                }
            }
        }

        /**
         * Partitions {@link #bags} by the recorded split test.
         *
         * @return the number of samples sent to the true side
         */
        private int splitSamples(@Nonnull IntArrayList trueBags,
                @Nonnull IntArrayList falseBags) {
            final int splitFeature = node.splitFeature;
            final double splitValue = node.splitValue;
            int tc = 0;
            if (node.splitFeatureType == Attribute.AttributeType.NOMINAL) {
                for (int index : bags) {
                    if (x[index][splitFeature] == splitValue) {
                        trueBags.add(index);
                        ++tc;
                    } else {
                        falseBags.add(index);
                    }
                }
            } else if (node.splitFeatureType == Attribute.AttributeType.NUMERIC) {
                for (int index : bags) {
                    if (x[index][splitFeature] <= splitValue) {
                        trueBags.add(index);
                        ++tc;
                    } else {
                        falseBags.add(index);
                    }
                }
            } else {
                throw new IllegalStateException(
                    "Unsupported attribute type: " + node.splitFeatureType);
            }
            return tc;
        }
    }

    /**
     * A node of a fitted regression tree. A node with no children is a leaf and predicts
     * {@link #output}; otherwise it routes a sample to {@link #trueChild} or {@link #falseChild}.
     *
     * <p>Serialization writes only {@code output}, {@code splitFeature},
     * {@code splitFeatureType}, {@code splitValue} and the children. {@code splitScore} and the
     * child outputs are training-time state and are not persisted.
     */
    public static final class Node implements Externalizable {
        /** Prediction at a leaf (mean response unless replaced by a {@link NodeOutput}). */
        double output = 0.0;
        /** Index of the tested attribute, or -1 when the node has no split. */
        int splitFeature = -1;
        Attribute.AttributeType splitFeatureType = null;
        /** Category (nominal) or threshold (numeric) of the split test. */
        double splitValue = Double.NaN;
        /** Score of the recorded split; used during training only. */
        double splitScore = 0.0;
        Node trueChild;
        Node falseChild;
        /** Output for the true child if this node is split; training only. */
        double trueChildOutput = 0.0;
        /** Output for the false child if this node is split; training only. */
        double falseChildOutput = 0.0;

        /** Creates an empty node, e.g. as a target for {@link #readExternal}. */
        public Node() {}

        /** Creates a leaf node predicting {@code output}. */
        public Node(double output) {
            this.output = output;
        }

        /** Returns whether this node has no children. */
        private boolean isLeaf() {
            return trueChild == null && falseChild == null;
        }

        /**
         * Predicts the response for {@code x}.
         *
         * @throws IllegalStateException if an internal node has an unsupported split type
         */
        public double predict(double[] x) {
            if (isLeaf()) {
                return output;
            }
            final boolean goTrue;
            if (splitFeatureType == Attribute.AttributeType.NOMINAL) {
                goTrue = x[splitFeature] == splitValue;
            } else if (splitFeatureType == Attribute.AttributeType.NUMERIC) {
                goTrue = x[splitFeature] <= splitValue;
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + splitFeatureType);
            }
            return goTrue ? trueChild.predict(x) : falseChild.predict(x);
        }

        /**
         * Predicts the response for integer-valued {@code x}.
         *
         * <p>Numeric splits are evaluated as {@code x[f] <= threshold}, exactly as in
         * {@link #predict(double[])}. Previously every split was treated as a nominal equality
         * test against the truncated threshold, which misrouted samples at numeric nodes.
         *
         * @throws IllegalStateException if an internal node has an unsupported split type
         */
        public double predict(int[] x) {
            if (isLeaf()) {
                return output;
            }
            final boolean goTrue;
            if (splitFeatureType == Attribute.AttributeType.NOMINAL) {
                goTrue = x[splitFeature] == (int) splitValue;
            } else if (splitFeatureType == Attribute.AttributeType.NUMERIC) {
                goTrue = x[splitFeature] <= splitValue;
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + splitFeatureType);
            }
            return goTrue ? trueChild.predict(x) : falseChild.predict(x);
        }

        /**
         * Appends JavaScript source for this subtree. Leaves emit their output value; internal
         * nodes emit an {@code if/else} over {@code x[splitFeature]}.
         */
        public void jsCodegen(@Nonnull StringBuilder builder, int depth) {
            if (isLeaf()) {
                indent(builder, depth);
                builder.append(output).append(";\n");
                return;
            }
            final String comparison;
            final String elseLine;
            if (splitFeatureType == Attribute.AttributeType.NOMINAL) {
                comparison = "] == ";
                elseLine = "} else {\n";
            } else if (splitFeatureType == Attribute.AttributeType.NUMERIC) {
                comparison = "] <= ";
                // The double space is kept so the generated code stays byte-identical.
                elseLine = "} else  {\n";
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + splitFeatureType);
            }
            indent(builder, depth);
            builder.append("if(x[").append(splitFeature).append(comparison).append(splitValue)
                .append(") {\n");
            trueChild.jsCodegen(builder, depth + 1);
            indent(builder, depth);
            builder.append(elseLine);
            falseChild.jsCodegen(builder, depth + 1);
            indent(builder, depth);
            builder.append("}\n");
        }

        /**
         * Appends stack-machine operations for this subtree to {@code scripts}.
         *
         * <p>An internal node emits {@code push x[f]}, {@code push value} and a jump
         * ({@code ifeq}/{@code ifle}). The jump's target is the index where the false subtree
         * starts, back-patched once the true subtree has been emitted.
         *
         * @param scripts operation list; its size is expected to equal {@code depth} on entry
         * @param depth index in {@code scripts} where this subtree starts
         * @return the number of operations emitted for this subtree
         */
        public int opCodegen(List<String> scripts, int depth) {
            if (isLeaf()) {
                scripts.add("push " + output);
                scripts.add("goto last");
                return 2;
            }
            final String jumpOp;
            if (splitFeatureType == Attribute.AttributeType.NOMINAL) {
                jumpOp = "ifeq ";
            } else if (splitFeatureType == Attribute.AttributeType.NUMERIC) {
                jumpOp = "ifle ";
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + splitFeatureType);
            }
            scripts.add("push x[" + splitFeature + "]");
            scripts.add("push " + splitValue);
            scripts.add(jumpOp); // placeholder, patched below
            final int trueStart = depth + 3;
            final int trueDepth = trueChild.opCodegen(scripts, trueStart);
            scripts.set(trueStart - 1, jumpOp + (trueStart + trueDepth));
            final int falseDepth = falseChild.opCodegen(scripts, trueStart + trueDepth);
            return 3 + trueDepth + falseDepth;
        }

        /**
         * Writes this subtree in pre-order. Each child is preceded by a presence flag; a
         * {@code null} split type is encoded as -1.
         */
        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            out.writeDouble(output);
            out.writeInt(splitFeature);
            out.writeInt(splitFeatureType == null ? -1 : splitFeatureType.getTypeId());
            out.writeDouble(splitValue);
            out.writeBoolean(trueChild != null);
            if (trueChild != null) {
                trueChild.writeExternal(out);
            }
            out.writeBoolean(falseChild != null);
            if (falseChild != null) {
                falseChild.writeExternal(out);
            }
        }

        /** Reads a subtree in the format produced by {@link #writeExternal}. */
        @Override
        public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
            this.output = in.readDouble();
            this.splitFeature = in.readInt();
            final int typeId = in.readInt();
            this.splitFeatureType = (typeId == -1) ? null : Attribute.AttributeType.resolve(typeId);
            this.splitValue = in.readDouble();
            if (in.readBoolean()) {
                this.trueChild = new Node();
                trueChild.readExternal(in);
            }
            if (in.readBoolean()) {
                this.falseChild = new Node();
                falseChild.readExternal(in);
            }
        }
    }

    /** Callback that computes a leaf's output from the samples reaching it. */
    public static interface NodeOutput {
        /**
         * @param var1 per-row sample counts for the leaf, as produced by
         *        {@code SmileExtUtils.bagsToSamples}
         * @return the leaf's output value
         */
        public double calculate(int[] var1);
    }
}

