/*
 * Decompiled with CFR 0.152.
 */
package hivemall.fm;

import hivemall.fm.Entry;
import hivemall.fm.FMHyperParameters;
import hivemall.fm.FactorizationMachineModel;
import hivemall.fm.Feature;
import hivemall.utils.collections.DoubleArray3D;
import hivemall.utils.collections.IntArrayList;
import hivemall.utils.lang.NumberUtils;
import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.metadata.HiveException;

/**
 * Base model for field-aware factorization machines (FFM).
 *
 * <p>Each latent factor is addressed by a feature, the field of the feature it interacts with,
 * and a factor index. The field-unaware {@code getV(Feature, int)} / {@code setV(Feature, int,
 * float)} inherited from {@link FactorizationMachineModel} are therefore unsupported here.
 * Subclasses supply storage through {@link #getV(Feature, int, int)} and
 * {@link #getEntry(Feature, int)}.
 */
public abstract class FieldAwareFactorizationMachineModel extends FactorizationMachineModel {

    @Nonnull
    protected final FMHyperParameters.FFMHyperParameters _params;
    /** Base learning rate for V used by the AdaGrad branch of {@link #etaV}. */
    protected final float _eta0_V;
    /** Term added to the accumulated squared gradients under the AdaGrad square root. */
    protected final float _eps;
    /** When true, {@link #etaV} uses a per-entry AdaGrad rate instead of the global schedule. */
    protected final boolean _useAdaGrad;
    /** FTRL flag copied from the hyper-parameters; not read in this class (presumably subclasses). */
    protected final boolean _useFTRL;

    public FieldAwareFactorizationMachineModel(
            @Nonnull FMHyperParameters.FFMHyperParameters params) {
        super(params);
        this._params = params;
        this._eta0_V = params.eta0_V;
        this._eps = params.eps;
        this._useAdaGrad = params.useAdaGrad;
        this._useFTRL = params.useFTRL;
    }

    /**
     * Returns the {@code f}-th latent factor of feature {@code x} used when {@code x} interacts
     * with a feature of field {@code yField}.
     *
     * @param x the feature owning the factor vector
     * @param yField field of the interacting feature
     * @param f factor index
     */
    public abstract float getV(@Nonnull Feature x, int yField, int f);

    /**
     * Sets the {@code f}-th latent factor of feature {@code x} for field {@code yField}.
     */
    @Deprecated
    protected abstract void setV(@Nonnull Feature x, int yField, int f, float nextVif);

    /** Unsupported: FFM factors are always field-aware. */
    @Override
    public float getV(Feature x, int f) {
        throw new UnsupportedOperationException();
    }

    /** Unsupported: FFM factors are always field-aware. */
    @Override
    protected void setV(Feature x, int f, float nextVif) {
        throw new UnsupportedOperationException();
    }

    /**
     * Computes the FFM score.
     *
     * <pre>
     * w0 + sum_i w_i x_i + sum_{i&lt;j} sum_f V[i, field(j), f] * V[j, field(i), f] * x_i * x_j
     * </pre>
     *
     * @param x features of one example
     * @return the raw prediction
     * @throws HiveException if the result is NaN or infinite; the message includes a
     *         {@link #varDump} of the computation
     */
    @Override
    protected final double predict(@Nonnull final Feature[] x) throws HiveException {
        double ret = getW0();

        // Linear terms.
        for (Feature e : x) {
            final double xi = e.getValue();
            final float wi = getW(e);
            ret += (double) wi * xi;
        }

        // Pairwise field-aware interaction terms.
        final int k = _factor;
        for (int i = 0; i < x.length; ++i) {
            final Feature ei = x[i];
            final double xi = ei.getValue();
            final short iField = ei.getField();
            for (int j = i + 1; j < x.length; ++j) {
                final Feature ej = x[j];
                final double xj = ej.getValue();
                final short jField = ej.getField();
                for (int f = 0; f < k; ++f) {
                    final float vijf = getV(ei, jField, f);
                    final float vjif = getV(ej, iField, f);
                    // The accumulation must stay outside the assert: assert expressions are
                    // not evaluated when assertions are disabled (the JVM default).
                    ret += (double) (vijf * vjif) * xi * xj;
                    assert !Double.isNaN(ret);
                }
            }
        }

        if (!NumberUtils.isFinite(ret)) {
            throw new HiveException("Detected " + ret
                    + " in predict. We recommend to normalize training examples.\n"
                    + "Dumping variables ...\n" + varDump(x));
        }
        return ret;
    }

    /**
     * Performs one gradient step on the {@code f}-th factor of {@code x} for field
     * {@code yField}, including the L2 term {@code 2 * lambdaV(f) * V}.
     *
     * @param dloss derivative of the loss w.r.t. the prediction
     * @param x the feature whose factor is updated
     * @param yField field the factor vector is specialised for
     * @param f factor index
     * @param sumViX precomputed sum from {@link #sumVfX(Feature[], IntArrayList, DoubleArray3D)}
     * @param t update counter passed to the learning-rate schedule
     * @throws IllegalStateException if the updated value is NaN or infinite
     */
    void updateV(final double dloss, @Nonnull final Feature x, final int yField, final int f,
            final double sumViX, final long t) {
        final double Xi = x.getValue();
        final double h = Xi * sumViX;
        final float gradV = (float) (dloss * h);
        final float lambdaVf = getLambdaV(f);

        final Entry theta = getEntry(x, yField);
        final float currentV = theta.getV(f);
        // etaV may record the gradient in theta (AdaGrad), so it is called after currentV is read.
        final float eta = etaV(theta, t, gradV);
        final float nextV = currentV - eta * (gradV + 2.0f * lambdaVf * currentV);
        if (!NumberUtils.isFinite(nextV)) {
            throw new IllegalStateException("Got " + nextV + " for next V" + f + '['
                    + x.getFeatureIndex() + "]\n" + "Xi=" + Xi + ", Vif=" + currentV + ", h=" + h
                    + ", gradV=" + gradV + ", lambdaVf=" + lambdaVf + ", dloss=" + dloss
                    + ", sumViX=" + sumViX);
        }
        theta.setV(f, nextV);
    }

    /**
     * Returns the learning rate for a V update.
     *
     * <p>With AdaGrad, the rate is {@code eta0_V / sqrt(eps + G)}. Here {@code G} is the
     * squared-gradient sum of {@code theta} read before {@code grad} is added to it. Otherwise
     * the global schedule {@code _eta.eta(t)} is used.
     */
    protected final float etaV(@Nonnull final Entry theta, final long t, final float grad) {
        if (_useAdaGrad) {
            final double gg = theta.getSumOfSquaredGradientsV();
            theta.addGradientV(grad);
            return (float) ((double) _eta0_V / Math.sqrt((double) _eps + gg));
        }
        return _eta.eta(t);
    }

    /**
     * Fills a {@code [x.length][fieldList.size()][_factor]} array with
     * {@link #sumVfX(Feature[], int, int, int)} for every feature, field and factor.
     *
     * @param cached array to reuse, or {@code null} to allocate one (sanity checks disabled)
     * @return {@code cached} (reconfigured) or the newly allocated array
     */
    @Nonnull
    final DoubleArray3D sumVfX(@Nonnull final Feature[] x, @Nonnull final IntArrayList fieldList,
            @Nullable final DoubleArray3D cached) {
        final int xSize = x.length;
        final int fieldSize = fieldList.size();
        final int factors = _factor;

        final DoubleArray3D mdarray;
        if (cached == null) {
            mdarray = new DoubleArray3D();
            mdarray.setSanityCheck(false);
        } else {
            mdarray = cached;
        }
        mdarray.configure(xSize, fieldSize, factors);

        for (int i = 0; i < xSize; ++i) {
            for (int fieldIndex = 0; fieldIndex < fieldSize; ++fieldIndex) {
                final int yField = fieldList.get(fieldIndex);
                for (int f = 0; f < factors; ++f) {
                    mdarray.set(i, fieldIndex, f, sumVfX(x, i, yField, f));
                }
            }
        }
        return mdarray;
    }

    /**
     * Computes {@code sum_j V[j, field(i), f] * x_j} over features {@code j} whose field is
     * {@code yField} and whose feature index differs from that of {@code x[i]}.
     *
     * <p>Multiplied by {@code x_i} in {@link #updateV}, this is the partial derivative of the
     * interaction term in {@link #predict} w.r.t. {@code V[i, yField, f]}.
     *
     * @throws IllegalStateException if the sum is NaN or infinite
     */
    private double sumVfX(@Nonnull final Feature[] x, final int i, final int yField,
            final int f) {
        final Feature xi = x[i];
        final int xiFeature = xi.getFeatureIndex();
        final short xiField = xi.getField();
        double ret = 0.0;
        for (Feature e : x) {
            if (e.getFeatureIndex() == xiFeature || e.getField() != yField) {
                continue; // skip x[i] itself and features outside the target field
            }
            final float Vjf = getV(e, xiField, f);
            // Weight by the partner's value x_j. The original weighted by x_i, which, combined
            // with the extra x_i factor in updateV, yielded x_i^2 * V instead of x_i * x_j * V.
            ret += (double) Vjf * e.getValue();
        }
        if (!NumberUtils.isFinite(ret)) {
            throw new IllegalStateException("Got " + ret + " for sumV[ " + i + "][ " + f
                    + "]X.\n" + "x = " + Arrays.toString(x));
        }
        return ret;
    }

    /** Returns the parameter entry associated with feature {@code x}. */
    @Nonnull
    protected abstract Entry getEntry(@Nonnull Feature x);

    /**
     * Returns the parameter entry holding the factors of {@code x} for field {@code yField}.
     * {@link #updateV} reads and writes V through it, and {@link #etaV} uses it for AdaGrad
     * state.
     */
    @Nonnull
    protected abstract Entry getEntry(@Nonnull Feature x, int yField);

    /**
     * Renders the prediction formula for {@code x} twice: once symbolically and once with the
     * numeric values substituted. Rendering stops at the first term that makes the running sum
     * non-finite.
     */
    @Override
    protected final String varDump(@Nonnull final Feature[] x) {
        final StringBuilder buf1 = new StringBuilder(1024); // symbolic form
        final StringBuilder buf2 = new StringBuilder(1024); // numeric form

        for (int i = 0; i < x.length; ++i) {
            final Feature e = x[i];
            final String j = e.getFeature();
            final double xj = e.getValue();
            if (i != 0) {
                buf1.append(", ");
            }
            buf1.append("x[").append(j).append("] = ").append(xj);
        }
        buf1.append("\n");

        double ret = getW0();
        buf1.append("predict(x) = w0");
        buf2.append("predict(x) = ").append(ret);

        for (Feature e : x) {
            final String i = e.getFeature();
            final double xi = e.getValue();
            final float wi = getW(e);
            buf1.append(" + (w[").append(i).append("] * x[").append(i).append("])");
            buf2.append(" + (").append(wi).append(" * ").append(xi).append(')');
            ret += (double) wi * xi;
            if (!NumberUtils.isFinite(ret)) {
                return finishDump(buf1, buf2, " + ... = ", ret);
            }
        }

        final int k = _factor;
        for (int i = 0; i < x.length; ++i) {
            final Feature ei = x[i];
            final String fi = ei.getFeature();
            final double xi = ei.getValue();
            final short iField = ei.getField();
            for (int j = i + 1; j < x.length; ++j) {
                final Feature ej = x[j];
                final String fj = ej.getFeature();
                final double xj = ej.getValue();
                final short jField = ej.getField();
                for (int f = 0; f < k; ++f) {
                    final float vijf = getV(ei, jField, f);
                    final float vjif = getV(ej, iField, f);
                    buf1.append(" + (v[i").append(fi).append("-j").append(jField).append("-f")
                        .append(f).append("] * v[j").append(fj).append("-i").append(iField)
                        .append("-f").append(f).append("] * x[").append(fi).append("] * x[")
                        .append(fj).append("])");
                    buf2.append(" + (").append(vijf).append(" * ").append(vjif).append(" * ")
                        .append(xi).append(" * ").append(xj).append(')');
                    ret += (double) (vijf * vjif) * xi * xj;
                    if (!NumberUtils.isFinite(ret)) {
                        return finishDump(buf1, buf2, " + ... = ", ret);
                    }
                }
            }
        }
        return finishDump(buf1, buf2, " = ", ret);
    }

    /**
     * Joins the symbolic and numeric dumps: each line is terminated by {@code eq} followed by
     * the running sum.
     */
    @Nonnull
    private static String finishDump(@Nonnull final StringBuilder symbolic,
            @Nonnull final StringBuilder numeric, @Nonnull final String eq, final double ret) {
        return symbolic.append(eq).append(ret).append('\n').append(numeric).append(eq).append(ret)
            .toString();
    }
}

