/*
 * Decompiled with CFR 0.152.
 */
package hivemall.ftvec.ranking;

import hivemall.UDTFWithOptions;
import hivemall.ftvec.ranking.PerEventPositiveOnlyFeedback;
import hivemall.ftvec.ranking.PositiveOnlyFeedback;
import hivemall.utils.collections.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.BitUtils;
import hivemall.utils.lang.Primitives;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Random;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
import org.apache.hadoop.io.IntWritable;

@Description(name="bpr_sampling", value="_FUNC_(int userId, List<int> posItems [, const string options])- Returns a relation consists of <int userId, int itemId>")
public final class BprSamplingUDTF
extends UDTFWithOptions {
    private PrimitiveObjectInspector userOI;
    private ListObjectInspector itemListOI;
    private PrimitiveObjectInspector itemElemOI;
    private PositiveOnlyFeedback feedback;
    private float samplingRate;
    private boolean withoutReplacement;
    private boolean pairSampling;
    private Object[] forwardObjs;
    private IntWritable userId;
    private IntWritable posItemId;
    private IntWritable negItemId;

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("sampling", "sampling_rate", true, "Sampling rates of positive items [default: 1.0]");
        opts.addOption("without_replacement", false, "Do sampling without-replacement sampling [default: false]");
        opts.addOption("uniform_pair_sampling", "pair_sampling", false, "Sampling pairs uniform from feedbacks [default: false]");
        opts.addOption("maxcol", "max_itemid", true, "Max item id index [default: -1]");
        return opts;
    }

    @Override
    protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = null;
        int maxItemId = -1;
        float samplingRate = 1.0f;
        boolean withoutReplacement = false;
        boolean pairSampling = false;
        if (argOIs.length >= 3) {
            String args = HiveUtils.getConstString(argOIs[2]);
            cl = this.parseOptions(args);
            maxItemId = Primitives.parseInt(cl.getOptionValue("max_itemid"), maxItemId);
            withoutReplacement = cl.hasOption("without_replacement");
            pairSampling = cl.hasOption("uniform_pair_sampling");
            samplingRate = Primitives.parseFloat(cl.getOptionValue("sampling_rate"), samplingRate);
            if (withoutReplacement && samplingRate > 1.0f) {
                throw new UDFArgumentException("sampling_rate MUST be in less than or equals to 1 where without-replacement is true: " + samplingRate);
            }
        }
        this.feedback = pairSampling ? new PerEventPositiveOnlyFeedback(maxItemId) : new PositiveOnlyFeedback(maxItemId);
        this.samplingRate = samplingRate;
        this.withoutReplacement = withoutReplacement;
        this.pairSampling = pairSampling;
        return cl;
    }

    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 2 && argOIs.length != 3) {
            throw new UDFArgumentException("_FUNC_(int userid, array<int> itemid, [, const string options]) takes at least two arguments");
        }
        this.userOI = HiveUtils.asIntegerOI(argOIs[0]);
        this.itemListOI = HiveUtils.asListOI(argOIs[1]);
        this.itemElemOI = HiveUtils.asIntegerOI(this.itemListOI.getListElementObjectInspector());
        this.processOptions(argOIs);
        this.userId = new IntWritable();
        this.posItemId = new IntWritable();
        this.negItemId = new IntWritable();
        this.forwardObjs = new Object[]{this.userId, this.posItemId, this.negItemId};
        ArrayList<String> fieldNames = new ArrayList<String>();
        ArrayList<WritableIntObjectInspector> fieldOIs = new ArrayList<WritableIntObjectInspector>();
        fieldNames.add("user");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("pos_item");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("neg_item");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    public void process(@Nonnull Object[] args) throws HiveException {
        int userId = PrimitiveObjectInspectorUtils.getInt((Object)args[0], (PrimitiveObjectInspector)this.userOI);
        BprSamplingUDTF.validateIndex(userId);
        this.addFeedback(userId, args[1]);
    }

    @Nullable
    private void addFeedback(int userId, @Nonnull Object arg) throws UDFArgumentException {
        int size = this.itemListOI.getListLength(arg);
        if (size == 0) {
            return;
        }
        int maxItemId = this.feedback.getMaxItemId();
        IntArrayList posItems = new IntArrayList(size);
        for (int i = 0; i < size; ++i) {
            Object elem = this.itemListOI.getListElement(arg, i);
            if (elem == null) continue;
            int index = PrimitiveObjectInspectorUtils.getInt((Object)elem, (PrimitiveObjectInspector)this.itemElemOI);
            BprSamplingUDTF.validateIndex(index);
            maxItemId = Math.max(index, maxItemId);
            posItems.add(index);
        }
        this.feedback.addFeedback(userId, posItems);
        this.feedback.setMaxItemId(maxItemId);
    }

    public void close() throws HiveException {
        int feedbacks = this.feedback.getTotalFeedbacks();
        if (feedbacks == 0) {
            return;
        }
        int numSamples = (int)((float)feedbacks * this.samplingRate);
        if (this.pairSampling) {
            PerEventPositiveOnlyFeedback evFeedback = (PerEventPositiveOnlyFeedback)this.feedback;
            if (this.withoutReplacement) {
                this.uniformPairSamplingWithoutReplacement(evFeedback, numSamples);
            } else {
                this.uniformPairSamplingWithReplacement(evFeedback, numSamples);
            }
        } else if (this.withoutReplacement) {
            this.uniformUserSamplingWithoutReplacement(this.feedback, numSamples);
        } else {
            this.uniformUserSamplingWithReplacement(this.feedback, numSamples);
        }
    }

    private void forward(int user, int posItem, int negItem) throws HiveException {
        assert (user >= 0) : user;
        assert (posItem >= 0) : posItem;
        assert (negItem >= 0) : negItem;
        this.userId.set(user);
        this.posItemId.set(posItem);
        this.negItemId.set(negItem);
        this.forward(this.forwardObjs);
    }

    private void uniformUserSamplingWithReplacement(@Nonnull PositiveOnlyFeedback feedback, int numSamples) throws HiveException {
        int numUsers = feedback.getNumUsers();
        if (numUsers == 0) {
            return;
        }
        int maxItemId = feedback.getMaxItemId();
        if (maxItemId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxItemId);
        }
        int numItems = maxItemId + 1;
        int[] users = feedback.getUsers();
        assert (users.length == numUsers);
        Random rand = new Random(31L);
        for (int i = 0; i < numSamples; ++i) {
            int negItem;
            int user = users[rand.nextInt(numUsers)];
            IntArrayList posItems = feedback.getItems(user, true);
            assert (posItems != null) : user;
            int size = posItems.size();
            assert (size > 0) : size;
            if (size == numItems) {
                --i;
                continue;
            }
            int posItemIndex = rand.nextInt(size);
            int posItem = posItems.fastGet(posItemIndex);
            while (posItems.contains(negItem = rand.nextInt(maxItemId))) {
            }
            this.forward(user, posItem, negItem);
        }
    }

    private void uniformUserSamplingWithoutReplacement(@Nonnull PositiveOnlyFeedback feedback, int numSamples) throws HiveException {
        int numUsers = feedback.getNumUsers();
        if (numUsers == 0) {
            return;
        }
        int maxItemId = feedback.getMaxItemId();
        if (maxItemId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxItemId);
        }
        int numItems = maxItemId + 1;
        BitSet userBits = new BitSet(numUsers);
        feedback.getUsers(userBits);
        Random rand = new Random(31L);
        for (int i = 0; i < numSamples && numUsers > 0; ++i) {
            int negItem;
            int nthUser = rand.nextInt(numUsers);
            int user = BitUtils.indexOfSetBit(userBits, nthUser);
            if (user == -1) {
                throw new HiveException("Cannot find " + nthUser + "-th user among " + numUsers + " users");
            }
            IntArrayList posItems = feedback.getItems(user, true);
            assert (posItems != null) : user;
            int size = posItems.size();
            assert (size > 0) : size;
            if (size == numItems) {
                --i;
                continue;
            }
            int posItemIndex = rand.nextInt(size);
            int posItem = posItems.fastGet(posItemIndex);
            while (posItems.contains(negItem = rand.nextInt(maxItemId))) {
            }
            posItems.remove(posItemIndex);
            if (posItems.isEmpty()) {
                feedback.removeFeedback(user);
                userBits.clear(user);
                --numUsers;
            }
            this.forward(user, posItem, negItem);
        }
    }

    private void uniformPairSamplingWithReplacement(@Nonnull PerEventPositiveOnlyFeedback feedback, int numSamples) throws HiveException {
        int numFeedbacks = feedback.getTotalFeedbacks();
        if (numFeedbacks == 0) {
            return;
        }
        int maxItemId = feedback.getMaxItemId();
        if (maxItemId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxItemId);
        }
        Random rand = new Random(31L);
        for (int i = 0; i < numSamples; ++i) {
            int negItem;
            int index = rand.nextInt(numFeedbacks);
            int user = feedback.getUser(index);
            int posItem = feedback.getPositiveItem(index);
            IntArrayList posItems = feedback.getItems(user, true);
            assert (posItems != null) : user;
            while (posItems.contains(negItem = rand.nextInt(maxItemId))) {
            }
            this.forward(user, posItem, negItem);
        }
    }

    private void uniformPairSamplingWithoutReplacement(@Nonnull PerEventPositiveOnlyFeedback feedback, int numSamples) throws HiveException {
        int[] perm;
        int numFeedbacks = feedback.getTotalFeedbacks();
        if (numFeedbacks == 0) {
            return;
        }
        int maxItemId = feedback.getMaxItemId();
        if (maxItemId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxItemId);
        }
        Random rand = new Random(31L);
        for (int index : perm = feedback.getRandomIndex(rand)) {
            int negItem;
            int user = feedback.getUser(index);
            int posItem = feedback.getPositiveItem(index);
            IntArrayList posItems = feedback.getItems(user, true);
            assert (posItems != null) : user;
            while (posItems.contains(negItem = rand.nextInt(maxItemId))) {
            }
            this.forward(user, posItem, negItem);
        }
    }

    private static void validateIndex(int index) throws UDFArgumentException {
        if (index < 0) {
            throw new UDFArgumentException("Negative index is not allowed: " + index);
        }
    }
}

