/*
 * Decompiled with CFR 0.152.
 */
package hivemall.evaluation;

import hivemall.utils.hadoop.WritableUtils;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;

@Description(name="f1score", value="_FUNC_(array[int], array[int]) - Return a F-measure/F1 score")
public final class FMeasureUDAF extends UDAF {

    /**
     * Evaluator for the {@code f1score} aggregate.
     * <p>
     * Sums true positives, total actual labels and total predicted labels over all
     * rows (micro-averaging), then computes F1 from those totals in {@link #terminate()}.
     */
    public static class Evaluator implements UDAFEvaluator {

        /** Accumulated counts; {@code null} until the first row or partial arrives. */
        private PartialResult partial;

        /** Resets the aggregation state so that no rows are considered seen. */
        public void init() {
            this.partial = null;
        }

        /**
         * Adds one row's label lists to the running counts.
         *
         * @param actual    the true labels of the row; {@code null} is treated as empty
         * @param predicted the predicted labels of the row; {@code null} is treated as empty
         * @return always {@code true}
         */
        public boolean iterate(List<IntWritable> actual, List<IntWritable> predicted) {
            if (this.partial == null) {
                this.partial = new PartialResult();
            }
            this.partial.updateScore(actual, predicted);
            return true;
        }

        /**
         * Returns the counts accumulated so far.
         *
         * @return the partial counts, or {@code null} if nothing has been accumulated
         */
        public PartialResult terminatePartial() {
            return this.partial;
        }

        /**
         * Adds another partial aggregation's counts into this evaluator.
         *
         * @param other partial counts to add; {@code null} is ignored
         * @return always {@code true}
         */
        public boolean merge(PartialResult other) {
            if (other == null) {
                return true;
            }
            if (this.partial == null) {
                this.partial = new PartialResult();
            }
            this.partial.merge(other);
            return true;
        }

        /**
         * Computes the final F1 score.
         *
         * @return the F1 score, {@code null} if nothing was accumulated, or {@code -1.0}
         *         when precision + recall is zero (see {@link #f1Score})
         */
        public DoubleWritable terminate() {
            if (this.partial == null) {
                return null;
            }
            double score = f1Score(this.partial);
            return WritableUtils.val(score);
        }

        /**
         * Returns the harmonic mean of precision and recall.
         * Returns the sentinel {@code -1.0} when both are zero, where the mean is undefined.
         */
        private static double f1Score(PartialResult partial) {
            final double precision = precision(partial);
            final double recall = recall(partial);
            final double divisor = precision + recall;
            if (divisor > 0.0) {
                return 2.0 * precision * recall / divisor;
            }
            return -1.0;
        }

        /** Returns tp / totalPredicted, or 0 when nothing was predicted. */
        private static double precision(PartialResult partial) {
            return partial.totalPredicted == 0L
                    ? 0.0
                    : (double) partial.tp / (double) partial.totalPredicted;
        }

        /** Returns tp / totalActual, or 0 when there were no actual labels. */
        private static double recall(PartialResult partial) {
            return partial.totalAcutal == 0L
                    ? 0.0
                    : (double) partial.tp / (double) partial.totalAcutal;
        }

        /** Mutable counters that make up the intermediate aggregation state. */
        public static class PartialResult {
            /** Number of predicted labels that also appear in the actual labels. */
            long tp = 0L;
            /**
             * Total number of actual labels seen.
             * NOTE: the misspelled name is kept unchanged to avoid breaking package-level references.
             */
            long totalAcutal = 0L;
            /** Total number of predicted labels seen. */
            long totalPredicted = 0L;

            PartialResult() {
            }

            /**
             * Adds one row's counts.
             * <p>
             * Each predicted element that is present in {@code actual} counts as one true
             * positive, so a duplicated prediction is counted once per occurrence.
             *
             * @param actual    true labels; {@code null} is treated as empty
             * @param predicted predicted labels; {@code null} is treated as empty
             */
            void updateScore(List<IntWritable> actual, List<IntWritable> predicted) {
                final int numActual = (actual == null) ? 0 : actual.size();
                final int numPredicted = (predicted == null) ? 0 : predicted.size();
                long countTp = 0L;
                if (numActual > 0 && numPredicted > 0) {
                    // Build the lookup set once: O(|actual| + |predicted|) instead of the
                    // O(|actual| * |predicted|) cost of repeated List.contains calls.
                    final Set<IntWritable> actualSet = new HashSet<IntWritable>(actual);
                    for (IntWritable p : predicted) {
                        if (actualSet.contains(p)) {
                            ++countTp;
                        }
                    }
                }
                this.tp += countTp;
                this.totalAcutal += numActual;
                this.totalPredicted += numPredicted;
            }

            /**
             * Adds another partial's counts to this one.
             * The counts must be summed; overwriting them would discard everything
             * accumulated before this merge.
             */
            void merge(PartialResult other) {
                this.tp += other.tp;
                this.totalAcutal += other.totalAcutal;
                this.totalPredicted += other.totalPredicted;
            }
        }
    }
}

