/*
 * Decompiled with CFR 0.152.
 */
package hivemall.ftvec.amplify;

import hivemall.common.RandomizedAmplifier;
import hivemall.utils.hadoop.HiveUtils;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.mapred.JobConf;

@Description(name="rand_amplify", value="_FUNC_(const int xtimes, const int num_buffers, *) - amplify the input records x-times in map-side")
public final class RandomAmplifierUDTF
extends GenericUDTF
implements RandomizedAmplifier.DropoutListener<Object[]> {
    private boolean useSeed;
    private long seed;
    private transient ObjectInspector[] argOIs;
    private transient RandomizedAmplifier<Object[]> amplifier;

    public void configure(MapredContext mapredContext) {
        JobConf jobconf = mapredContext.getJobConf();
        String seed = jobconf.get("hivemall.amplify.seed");
        boolean bl = this.useSeed = seed != null;
        if (this.useSeed) {
            this.seed = Long.parseLong(seed);
        }
    }

    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        int numArgs = argOIs.length;
        if (numArgs < 3) {
            throw new UDFArgumentException("_FUNC_(int xtimes, int num_buffers, *) takes at least three arguments");
        }
        int xtimes = HiveUtils.getAsConstInt(argOIs[0]);
        if (xtimes < 1) {
            throw new UDFArgumentException("Illegal xtimes value: " + xtimes);
        }
        int numBuffers = HiveUtils.getAsConstInt(argOIs[1]);
        if (numBuffers < 2) {
            throw new UDFArgumentException("num_buffers must be greater than 2: " + numBuffers);
        }
        this.argOIs = argOIs;
        this.amplifier = this.useSeed ? new RandomizedAmplifier(numBuffers, xtimes, this.seed) : new RandomizedAmplifier(numBuffers, xtimes);
        this.amplifier.setDropoutListener(this);
        if (this.useSeed) {
            LogFactory.getLog(RandomAmplifierUDTF.class).info((Object)("rand_amplify() using seed: " + this.seed));
        }
        ArrayList<String> fieldNames = new ArrayList<String>();
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        for (int i = 2; i < numArgs; ++i) {
            fieldNames.add("c" + (i - 1));
            ObjectInspector rawOI = argOIs[i];
            ObjectInspector retOI = ObjectInspectorUtils.getStandardObjectInspector((ObjectInspector)rawOI, (ObjectInspectorUtils.ObjectInspectorCopyOption)ObjectInspectorUtils.ObjectInspectorCopyOption.DEFAULT);
            fieldOIs.add(retOI);
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    public void process(Object[] args) throws HiveException {
        Object[] row = new Object[args.length - 2];
        for (int i = 2; i < args.length; ++i) {
            Object arg = args[i];
            ObjectInspector argOI = this.argOIs[i];
            row[i - 2] = ObjectInspectorUtils.copyToStandardObject((Object)arg, (ObjectInspector)argOI, (ObjectInspectorUtils.ObjectInspectorCopyOption)ObjectInspectorUtils.ObjectInspectorCopyOption.DEFAULT);
        }
        this.amplifier.add(row);
    }

    public void close() throws HiveException {
        this.amplifier.sweepAll();
        this.amplifier = null;
    }

    @Override
    public void onDrop(Object[] row) throws HiveException {
        this.forward(row);
    }
}

