/*
 * Decompiled with CFR 0.152.
 */
package hivemall.evaluation;

import hivemall.evaluation.BinaryResponsesMeasures;
import hivemall.utils.hadoop.HiveUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;

@Description(name="ndcg", value="_FUNC_(array rankItems, array correctItems [, const boolean binaryResponses = true]) - Returns nDCG")
public final class NDCGUDAF
extends AbstractGenericUDAFResolver {
    private NDCGUDAF() {
    }

    public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException {
        if (typeInfo.length != 2 && typeInfo.length != 3) {
            throw new UDFArgumentTypeException(typeInfo.length - 1, "_FUNC_ takes two or three arguments");
        }
        boolean binaryResponses = true;
        if (typeInfo.length == 3 && !(binaryResponses = HiveUtils.isBooleanTypeInfo(typeInfo[2]))) {
            throw new UDFArgumentException("nDCG computation for Graded Responses is not supported yet");
        }
        ListTypeInfo arg1type = HiveUtils.asListTypeInfo(typeInfo[0]);
        if (!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo())) {
            throw new UDFArgumentTypeException(0, "The first argument `array rankItems` is invalid form: " + typeInfo[0]);
        }
        ListTypeInfo arg2type = HiveUtils.asListTypeInfo(typeInfo[1]);
        if (!HiveUtils.isPrimitiveTypeInfo(arg2type.getListElementTypeInfo())) {
            throw new UDFArgumentTypeException(1, "The first argument `array rankItems` is invalid form: " + typeInfo[1]);
        }
        return new Evaluator();
    }

    public static class NDCGAggregationBuffer
    implements GenericUDAFEvaluator.AggregationBuffer {
        double sum;
        long count;

        void reset() {
            this.sum = 0.0;
            this.count = 0L;
        }

        void merge(double o_sum, long o_count) {
            this.sum += o_sum;
            this.count += o_count;
        }

        double get() {
            if (this.count == 0L) {
                return 0.0;
            }
            return this.sum / (double)this.count;
        }

        void iterate(@Nonnull List<?> rankedList, @Nonnull List<?> correctList) {
            this.sum += BinaryResponsesMeasures.nDCG(rankedList, correctList);
            ++this.count;
        }
    }

    public static class Evaluator
    extends GenericUDAFEvaluator {
        private ListObjectInspector rankedListOI;
        private ListObjectInspector correctListOI;
        private StructObjectInspector internalMergeOI;
        private StructField countField;
        private StructField sumField;

        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] parameters) throws HiveException {
            assert (parameters.length == 2) : parameters.length;
            super.init(mode, parameters);
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.rankedListOI = (ListObjectInspector)parameters[0];
                this.correctListOI = (ListObjectInspector)parameters[1];
            } else {
                StructObjectInspector soi;
                this.internalMergeOI = soi = (StructObjectInspector)parameters[0];
                this.countField = soi.getStructFieldRef("count");
                this.sumField = soi.getStructFieldRef("sum");
            }
            Object outputOI = mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2 ? Evaluator.internalMergeOI() : PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
            return outputOI;
        }

        private static StructObjectInspector internalMergeOI() {
            ArrayList<String> fieldNames = new ArrayList<String>();
            ArrayList<Object> fieldOIs = new ArrayList<Object>();
            fieldNames.add("sum");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            fieldNames.add("count");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        public GenericUDAFEvaluator.AggregationBuffer getNewAggregationBuffer() throws HiveException {
            NDCGAggregationBuffer myAggr = new NDCGAggregationBuffer();
            this.reset(myAggr);
            return myAggr;
        }

        public void reset(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            NDCGAggregationBuffer myAggr = (NDCGAggregationBuffer)agg;
            myAggr.reset();
        }

        public void iterate(GenericUDAFEvaluator.AggregationBuffer agg, Object[] parameters) throws HiveException {
            List correctList;
            NDCGAggregationBuffer myAggr = (NDCGAggregationBuffer)agg;
            List rankedList = this.rankedListOI.getList(parameters[0]);
            if (rankedList == null) {
                rankedList = Collections.emptyList();
            }
            if ((correctList = this.correctListOI.getList(parameters[1])) == null) {
                return;
            }
            myAggr.iterate(rankedList, correctList);
        }

        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            NDCGAggregationBuffer myAggr = (NDCGAggregationBuffer)agg;
            Object[] partialResult = new Object[]{new DoubleWritable(myAggr.sum), new LongWritable(myAggr.count)};
            return partialResult;
        }

        public void merge(GenericUDAFEvaluator.AggregationBuffer agg, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            Object sumObj = this.internalMergeOI.getStructFieldData(partial, this.sumField);
            Object countObj = this.internalMergeOI.getStructFieldData(partial, this.countField);
            double sum = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj);
            long count = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(countObj);
            NDCGAggregationBuffer myAggr = (NDCGAggregationBuffer)agg;
            myAggr.merge(sum, count);
        }

        public DoubleWritable terminate(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            NDCGAggregationBuffer myAggr = (NDCGAggregationBuffer)agg;
            double result = myAggr.get();
            return new DoubleWritable(result);
        }
    }
}

