/*
 * Decompiled with CFR 0.152.
 */
package hivemall.classifier;

import hivemall.LearnerBaseUDTF;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaIntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;

/**
 * Base class for online binary classifiers exposed as Hive UDTFs.
 *
 * <p>Each {@link #process(Object[])} call consumes one training example (a list of features and
 * an integer label) and updates {@link #model}. {@link #close()} then forwards one
 * {@code (feature, weight[, covar])} row per touched weight. The default
 * {@link #train(FeatureValue[], int)} calls {@link #update(FeatureValue[], float, float)} on a
 * non-positive margin. That method throws unless a subclass overrides it, so concrete learners
 * override either {@code train} or {@code update}.
 */
public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF {
    private static final Log logger = LogFactory.getLog(BinaryOnlineClassifierUDTF.class);

    /** Inspector for the first argument (the feature list). */
    private ListObjectInspector featureListOI;
    /** Int-compatible inspector for the second argument (the label). */
    private PrimitiveObjectInspector labelOI;
    /** True when list elements are strings that must be parsed by {@link FeatureValue#parse}. */
    private boolean parseFeature;

    /** The model being trained; set to null once it has been forwarded in {@link #close()}. */
    protected PredictionModel model;
    /** Number of training examples passed to {@link #train(FeatureValue[], int)} by process(). */
    protected int count;

    /**
     * Validates the arguments, creates (and optionally preloads) the model, and builds the
     * output row schema.
     *
     * @param argOIs {@code features} (a list), {@code label} (int-compatible), then optional
     *        constant options handled by {@code processOptions}
     * @return inspector for {@code (feature, weight[, covar])} output rows
     * @throws UDFArgumentException if fewer than two arguments are given or an argument has an
     *         unsupported type
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 2) {
            throw new UDFArgumentException(getClass().getSimpleName()
                    + " takes 2 arguments: List<Int|BigInt|Text> features, int label [, constant string options]");
        }
        PrimitiveObjectInspector featureInputOI = processFeaturesOI(argOIs[0]);
        this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
        processOptions(argOIs);

        // With dense_model the emitted feature column is int; otherwise it mirrors the input
        // feature type. Typed as the common supertype so both branches are assignable.
        PrimitiveObjectInspector featureOutputOI = dense_model
                ? PrimitiveObjectInspectorFactory.javaIntObjectInspector
                : featureInputOI;

        this.model = createModel();
        if (preloadedModelFile != null) {
            loadPredictionModel(model, preloadedModelFile, featureOutputOI);
        }
        this.count = 0;
        return getReturnOI(featureOutputOI);
    }

    /**
     * Validates the feature-list argument and records how its elements must be read.
     *
     * @param arg inspector of the first UDTF argument
     * @return the primitive inspector of the list elements
     * @throws UDFArgumentException if {@code arg} is not a list or its element type is rejected
     *         by {@code HiveUtils.validateFeatureOI}
     */
    protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg) throws UDFArgumentException {
        if (!(arg instanceof ListObjectInspector)) {
            throw new UDFArgumentException("features must be a list, but got: " + arg.getTypeName());
        }
        this.featureListOI = (ListObjectInspector) arg;
        ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector();
        HiveUtils.validateFeatureOI(featureRawOI);
        this.parseFeature = HiveUtils.isStringOI(featureRawOI);
        return HiveUtils.asPrimitiveObjectInspector(featureRawOI);
    }

    /**
     * Builds the output schema: {@code feature}, {@code weight}, and {@code covar} when
     * {@code useCovariance()} is true.
     *
     * @param featureRawOI inspector describing the emitted feature column
     * @return the standard struct inspector for output rows
     */
    protected StructObjectInspector getReturnOI(ObjectInspector featureRawOI) {
        List<String> fieldNames = new ArrayList<String>(3);
        List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(3);

        fieldNames.add("feature");
        fieldOIs.add(ObjectInspectorUtils.getStandardObjectInspector(featureRawOI));
        fieldNames.add("weight");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        if (useCovariance()) {
            fieldNames.add("covar");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Trains on one example. Rows with a NULL or empty feature list are skipped.
     *
     * @param args {@code [features, label, ...]}
     * @throws HiveException if the label is NULL or {@link #checkLabelValue(int)} rejects it
     */
    @Override
    public void process(Object[] args) throws HiveException {
        List<?> features = featureListOI.getList(args[0]);
        if (features == null) {
            return; // NULL feature list: nothing to learn from
        }
        FeatureValue[] featureVector = parseFeatures(features);
        if (featureVector == null) {
            return;
        }
        if (args[1] == null) {
            throw new UDFArgumentException("label must not be NULL");
        }
        int label = PrimitiveObjectInspectorUtils.getInt(args[1], labelOI);
        checkLabelValue(label);
        count++;
        train(featureVector, label);
    }

    /**
     * Converts raw list elements into feature values.
     *
     * <p>String elements are parsed with {@link FeatureValue#parse}. Other elements are copied
     * to standard objects and given value {@code 1.0}. NULL elements leave a null slot in the
     * returned array, and downstream loops skip such slots.
     *
     * @param features the raw feature list
     * @return the feature vector, or {@code null} if {@code features} is empty
     */
    @Nullable
    protected final FeatureValue[] parseFeatures(@Nonnull List<?> features) {
        final int size = features.size();
        if (size == 0) {
            return null;
        }
        final ObjectInspector featureInspector = featureListOI.getListElementObjectInspector();
        final FeatureValue[] featureVector = new FeatureValue[size];
        for (int i = 0; i < size; i++) {
            Object f = features.get(i);
            if (f == null) {
                continue;
            }
            if (parseFeature) {
                featureVector[i] = FeatureValue.parse(f);
            } else {
                // copy so the key does not alias a buffer Hive may reuse for the next row
                Object k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
                featureVector[i] = new FeatureValue(k, 1.0f);
            }
        }
        return featureVector;
    }

    /**
     * Checks a label value. This implementation only asserts (when assertions are enabled)
     * that the label is -1, 0 or 1; subclasses may override to enforce it.
     *
     * @param label the integer label
     * @throws UDFArgumentException never thrown here; declared for overriding implementations
     */
    protected void checkLabelValue(int label) throws UDFArgumentException {
        assert (label == -1 || label == 0 || label == 1) : label;
    }

    /**
     * Trains on a raw feature list without touching {@link #count} or validating the label.
     * An empty list is ignored.
     *
     * @param features raw feature list
     * @param label the integer label
     */
    void train(@Nonnull List<?> features, int label) {
        FeatureValue[] featureVector = parseFeatures(features);
        if (featureVector == null) {
            return; // empty input: avoid passing null into train(FeatureValue[], int)
        }
        train(featureVector, label);
    }

    /**
     * Mistake-driven training step: maps {@code label > 0} to {@code +1} and anything else to
     * {@code -1}, and calls {@link #update(FeatureValue[], float, float)} when the margin
     * {@code p * y} is non-positive.
     *
     * @param features feature vector (null slots are skipped)
     * @param label the integer label
     */
    protected void train(@Nonnull FeatureValue[] features, int label) {
        final float y = label > 0 ? 1.0f : -1.0f;
        final float p = predict(features);
        if (p * y <= 0.0f) {
            update(features, y, p);
        }
    }

    /**
     * Computes the linear score {@code sum(w[k] * v)} over non-null features.
     *
     * @param features feature vector (null slots are skipped)
     * @return the score
     */
    protected float predict(@Nonnull FeatureValue[] features) {
        float score = 0.0f;
        for (FeatureValue f : features) {
            if (f == null) {
                continue;
            }
            float w = model.getWeight(f.getFeature());
            if (w != 0.0f) {
                score += w * f.getValueAsFloat();
            }
        }
        return score;
    }

    /**
     * Computes the linear score together with the squared L2 norm of the feature values.
     *
     * @param features feature vector (null slots are skipped)
     * @return a result holding the score and squared norm
     */
    @Nonnull
    protected PredictionResult calcScoreAndNorm(@Nonnull FeatureValue[] features) {
        float score = 0.0f;
        float squaredNorm = 0.0f;
        for (FeatureValue f : features) {
            if (f == null) {
                continue;
            }
            float v = f.getValueAsFloat();
            float w = model.getWeight(f.getFeature());
            if (w != 0.0f) {
                score += w * v;
            }
            squaredNorm += v * v;
        }
        return new PredictionResult(score).squaredNorm(squaredNorm);
    }

    /**
     * Computes the linear score and the variance {@code sum(covar[k] * v^2)}. A feature with
     * no stored weight contributes a covariance of 1 and no score.
     *
     * @param features feature vector (null slots are skipped)
     * @return a result holding the score and variance
     */
    @Nonnull
    protected PredictionResult calcScoreAndVariance(@Nonnull FeatureValue[] features) {
        float score = 0.0f;
        float variance = 0.0f;
        for (FeatureValue f : features) {
            if (f == null) {
                continue;
            }
            float v = f.getValueAsFloat();
            IWeightValue w = model.get(f.getFeature());
            if (w == null) {
                variance += 1.0f * v * v;
            } else {
                score += w.get() * v;
                variance += w.getCovariance() * v * v;
            }
        }
        return new PredictionResult(score).variance(variance);
    }

    /**
     * Hook invoked by {@link #train(FeatureValue[], int)} on a non-positive margin. This
     * implementation always throws; subclasses using the default {@code train} must override.
     *
     * @param features feature vector
     * @param y target label, {@code +1} or {@code -1}
     * @param p the predicted score
     * @throws IllegalStateException always
     */
    protected void update(@Nonnull FeatureValue[] features, float y, float p) {
        throw new IllegalStateException("update() should not be called");
    }

    /**
     * Additive update {@code w[k] += coeff * v} for every non-null feature.
     *
     * @param features feature vector (null slots are skipped)
     * @param coeff step coefficient applied to each feature value
     */
    protected void update(@Nonnull FeatureValue[] features, float coeff) {
        for (FeatureValue f : features) {
            if (f == null) {
                continue;
            }
            Object k = f.getFeature();
            float newWeight = model.getWeight(k) + coeff * f.getValueAsFloat();
            model.set(k, new WeightValue(newWeight));
        }
    }

    /**
     * Forwards every touched weight as an output row, logs training statistics, and drops the
     * reference to the model. Does nothing further if the model was never created or was
     * already forwarded.
     *
     * @throws HiveException if the superclass close or forwarding a row fails
     */
    @Override
    public final void close() throws HiveException {
        super.close();
        if (model == null) {
            return;
        }
        final int numForwarded = useCovariance() ? forwardModelWithCovariance() : forwardModel();
        final long numMixed = model.getNumMixed();
        this.model = null;
        logger.info("Trained a prediction model using " + count + " training examples"
                + (numMixed > 0L ? "( numMixed: " + numMixed + " )" : ""));
        logger.info("Forwarded the prediction model of " + numForwarded + " rows");
    }

    /** Forwards a {@code (feature, weight)} row per touched weight; returns the row count. */
    private int forwardModel() throws HiveException {
        final WeightValue probe = new WeightValue();
        // row array and writable are reused across forward() calls
        final Object[] row = new Object[2];
        final FloatWritable weight = new FloatWritable();
        int numForwarded = 0;
        // raw type kept as in the original block; key/value types are not visible here
        final IMapIterator itor = model.entries();
        while (itor.next() != -1) {
            itor.getValue(probe);
            if (!probe.isTouched()) {
                continue;
            }
            weight.set(probe.get());
            row[0] = itor.getKey();
            row[1] = weight;
            forward(row);
            numForwarded++;
        }
        return numForwarded;
    }

    /** Forwards a {@code (feature, weight, covar)} row per touched weight; returns the row count. */
    private int forwardModelWithCovariance() throws HiveException {
        final WeightValue.WeightValueWithCovar probe = new WeightValue.WeightValueWithCovar();
        // row array and writables are reused across forward() calls
        final Object[] row = new Object[3];
        final FloatWritable weight = new FloatWritable();
        final FloatWritable covar = new FloatWritable();
        int numForwarded = 0;
        final IMapIterator itor = model.entries();
        while (itor.next() != -1) {
            itor.getValue(probe);
            if (!probe.isTouched()) {
                continue;
            }
            weight.set(probe.get());
            covar.set(probe.getCovariance());
            row[0] = itor.getKey();
            row[1] = weight;
            row[2] = covar;
            forward(row);
            numForwarded++;
        }
        return numForwarded;
    }
}

