/*
 * Decompiled with CFR 0.152.
 */
package hivemall.smile.utils;

import hivemall.smile.classification.DecisionTree;
import hivemall.smile.data.Attribute;
import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import smile.data.Attribute;
import smile.math.Math;
import smile.math.Random;
import smile.sort.QuickSort;

public final class SmileExtUtils {
    private SmileExtUtils() {
    }

    @Nullable
    public static hivemall.smile.data.Attribute[] resolveAttributes(@Nullable String opt) throws UDFArgumentException {
        if (opt == null) {
            return null;
        }
        String[] opts = opt.split(",");
        int size = opts.length;
        hivemall.smile.data.Attribute[] attr = new hivemall.smile.data.Attribute[size];
        for (int i = 0; i < size; ++i) {
            String type = opts[i];
            if ("Q".equals(type)) {
                attr[i] = new Attribute.NumericAttribute(i);
                continue;
            }
            if ("C".equals(type)) {
                attr[i] = new Attribute.NominalAttribute(i);
                continue;
            }
            throw new UDFArgumentException("Unexpected type: " + type);
        }
        return attr;
    }

    @Nonnull
    public static hivemall.smile.data.Attribute[] attributeTypes(@Nullable hivemall.smile.data.Attribute[] attributes, @Nonnull double[][] x) {
        if (attributes == null) {
            int p = x[0].length;
            attributes = new hivemall.smile.data.Attribute[p];
            for (int i = 0; i < p; ++i) {
                attributes[i] = new Attribute.NumericAttribute(i);
            }
        } else {
            int size = attributes.length;
            for (int j = 0; j < size; ++j) {
                hivemall.smile.data.Attribute attr = attributes[j];
                if (attr.type != Attribute.AttributeType.NOMINAL || attr.getSize() != -1) continue;
                int max_x = 0;
                for (int i = 0; i < x.length; ++i) {
                    int x_ij = (int)x[i][j];
                    if (x_ij <= max_x) continue;
                    max_x = x_ij;
                }
                attr.setSize(max_x + 1);
            }
        }
        return attributes;
    }

    @Nonnull
    public static hivemall.smile.data.Attribute[] convertAttributeTypes(@Nonnull Attribute[] original) {
        int size = original.length;
        hivemall.smile.data.Attribute[] dst = new hivemall.smile.data.Attribute[size];
        block4: for (int i = 0; i < size; ++i) {
            Attribute o = original[i];
            switch (o.type) {
                case NOMINAL: {
                    dst[i] = new Attribute.NominalAttribute(i);
                    continue block4;
                }
                case NUMERIC: {
                    dst[i] = new Attribute.NumericAttribute(i);
                    continue block4;
                }
                default: {
                    throw new UnsupportedOperationException("Unsupported type: " + (Object)((Object)o.type));
                }
            }
        }
        return dst;
    }

    @Nonnull
    public static int[][] sort(@Nonnull hivemall.smile.data.Attribute[] attributes, @Nonnull double[][] x) {
        int n = x.length;
        int p = x[0].length;
        double[] a = new double[n];
        int[][] index = new int[p][];
        for (int j = 0; j < p; ++j) {
            if (attributes[j].type != Attribute.AttributeType.NUMERIC) continue;
            for (int i = 0; i < n; ++i) {
                a[i] = x[i][j];
            }
            index[j] = QuickSort.sort(a);
        }
        return index;
    }

    @Nonnull
    public static int[] classLables(@Nonnull int[] y) throws HiveException {
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        if (labels.length < 2) {
            throw new HiveException("Only one class.");
        }
        for (int i = 0; i < labels.length; ++i) {
            if (labels[i] < 0) {
                throw new HiveException("Negative class label: " + labels[i]);
            }
            if (i <= 0 || labels[i] - labels[i - 1] <= 1) continue;
            throw new HiveException("Missing class: " + labels[i] + 1);
        }
        return labels;
    }

    @Nonnull
    public static DecisionTree.SplitRule resolveSplitRule(@Nullable String ruleName) {
        if ("gini".equalsIgnoreCase(ruleName)) {
            return DecisionTree.SplitRule.GINI;
        }
        if ("entropy".equalsIgnoreCase(ruleName)) {
            return DecisionTree.SplitRule.ENTROPY;
        }
        if ("classification_error".equalsIgnoreCase(ruleName)) {
            return DecisionTree.SplitRule.CLASSIFICATION_ERROR;
        }
        return DecisionTree.SplitRule.GINI;
    }

    public static int computeNumInputVars(float numVars, double[][] x) {
        int numInputVars;
        if (numVars <= 0.0f) {
            int dims = x[0].length;
            numInputVars = (int)java.lang.Math.ceil(java.lang.Math.sqrt(dims));
        } else {
            numInputVars = numVars > 0.0f && numVars <= 1.0f ? (int)(numVars * (float)x[0].length) : (int)numVars;
        }
        return numInputVars;
    }

    public static long generateSeed() {
        return Thread.currentThread().getId() * System.nanoTime();
    }

    public static void shuffle(@Nonnull int[] x, @Nonnull Random rnd) {
        for (int i = x.length; i > 1; --i) {
            int j = rnd.nextInt(i);
            SmileExtUtils.swap(x, i - 1, j);
        }
    }

    public static void shuffle(@Nonnull double[][] x, int[] y, @Nonnull long seed) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("x.length (" + x.length + ") != y.length (" + y.length + ')');
        }
        if (seed == -1L) {
            seed = SmileExtUtils.generateSeed();
        }
        Random rnd = new Random(seed);
        for (int i = x.length; i > 1; --i) {
            int j = rnd.nextInt(i);
            SmileExtUtils.swap(x, i - 1, j);
            SmileExtUtils.swap(y, i - 1, j);
        }
    }

    public static void shuffle(@Nonnull double[][] x, double[] y, @Nonnull long seed) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("x.length (" + x.length + ") != y.length (" + y.length + ')');
        }
        if (seed == -1L) {
            seed = SmileExtUtils.generateSeed();
        }
        Random rnd = new Random(seed);
        for (int i = x.length; i > 1; --i) {
            int j = rnd.nextInt(i);
            SmileExtUtils.swap(x, i - 1, j);
            SmileExtUtils.swap(y, i - 1, j);
        }
    }

    private static void swap(int[] x, int i, int j) {
        int s = x[i];
        x[i] = x[j];
        x[j] = s;
    }

    private static void swap(double[] x, int i, int j) {
        double s = x[i];
        x[i] = x[j];
        x[j] = s;
    }

    private static void swap(double[][] x, int i, int j) {
        double[] s = x[i];
        x[i] = x[j];
        x[j] = s;
    }

    @Nonnull
    public static int[] bagsToSamples(@Nonnull int[] bags) {
        int maxIndex = -1;
        for (int e : bags) {
            if (e <= maxIndex) continue;
            maxIndex = e;
        }
        return SmileExtUtils.bagsToSamples(bags, maxIndex + 1);
    }

    @Nonnull
    public static int[] bagsToSamples(@Nonnull int[] bags, int samplesLength) {
        int[] samples = new int[samplesLength];
        for (int n : bags) {
            samples[n] = samples[n] + 1;
        }
        return samples;
    }

    public static boolean containsNumericType(@Nonnull hivemall.smile.data.Attribute[] attributes) {
        for (hivemall.smile.data.Attribute attr : attributes) {
            if (attr.type != Attribute.AttributeType.NUMERIC) continue;
            return true;
        }
        return false;
    }
}

