/*
 * Decompiled with CFR 0.152.
 */
package hivemall.fm;

import hivemall.fm.Entry;
import hivemall.fm.FFMPredictionModel;
import hivemall.fm.FMHyperParameters;
import hivemall.fm.Feature;
import hivemall.fm.FieldAwareFactorizationMachineModel;
import hivemall.utils.buffer.HeapBuffer;
import hivemall.utils.collections.Int2LongOpenHashTable;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.math.MathUtils;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/**
 * Field-aware factorization machine model whose parameter entries are stored in a
 * {@link HeapBuffer}.
 *
 * <p>Every entry is addressed by an {@code int} key. The linear weight W of a feature is read
 * from the entry keyed by {@link Feature#getFeatureIndex()}. The per-field factors V are read
 * from the entry keyed by {@link Feature#toIntFeature}. {@link #_map} translates a key into
 * the entry's offset inside {@link #_buf}.
 *
 * <p>Missing entries are created lazily, with V taken from {@code initV()}. This happens in the
 * setters and also in {@link #getV} and the {@code getEntry(Feature...)} accessors, so reads
 * can grow the model. {@link #getW} is the exception: it returns 0 without creating anything.
 */
public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachineModel {

    /** Initial capacity of the key-to-offset table. */
    private static final int DEFAULT_MAPSIZE = 65536;

    /** The w0 term exposed via {@link #getW0()} / {@link #setW0(float)}. */
    private float _w0 = 0.0f;

    /** Entry key -> offset of that entry in {@link #_buf}; {@code get} yields -1 when absent. */
    @Nonnull
    private final Int2LongOpenHashTable _map = new Int2LongOpenHashTable(DEFAULT_MAPSIZE);

    /** Backing storage for all entries (constructed with 0x400000 = 4 * 1024 * 1024). */
    @Nonnull
    private final HeapBuffer _buf = new HeapBuffer(0x400000);

    private final int _numFeatures;
    private final int _numFields;

    // FTRL hyper-parameters; only read by updateWiFTRL.
    private final float _alpha;
    private final float _beta;
    private final float _lambda1;
    private final float _lambda2;

    /** Allocation size of one entry, as reported by the selected Entry variant's sizeOf. */
    private final int _entrySize;

    /**
     * Creates an empty model.
     *
     * @param params hyper-parameters. Feature/field counts and the FTRL coefficients are copied
     *        from here. The factor count and optimizer flags come from the superclass.
     */
    public FFMStringFeatureMapModel(@Nonnull FMHyperParameters.FFMHyperParameters params) {
        super(params);
        this._numFeatures = params.numFeatures;
        this._numFields = params.numFields;
        this._alpha = params.alphaFTRL;
        this._beta = params.betaFTRL;
        this._lambda1 = params.lambda1;
        this._lambda2 = params.lamdda2;
        this._entrySize = entrySize(_factor, _useFTRL, _useAdaGrad);
    }

    /**
     * Builds a prediction model over this model's state.
     *
     * <p>The internal table and buffer are passed by reference, not copied here. Later updates
     * to this model may therefore be visible through the returned object, depending on what
     * {@code FFMPredictionModel} does with them.
     */
    @Nonnull
    FFMPredictionModel toPredictionModel() {
        return new FFMPredictionModel(_map, _buf, _w0, _factor, _numFeatures, _numFields);
    }

    /** Returns the number of keys currently registered in the entry table. */
    @Override
    public int getSize() {
        return _map.size();
    }

    @Override
    public float getW0() {
        return _w0;
    }

    @Override
    protected void setW0(float nextW0) {
        this._w0 = nextW0;
    }

    /**
     * Returns the linear weight of {@code x}.
     *
     * @return the stored weight, or 0 when no entry exists. No entry is created.
     */
    @Override
    public float getW(@Nonnull Feature x) {
        Entry entry = getEntry(x.getFeatureIndex());
        return (entry == null) ? 0.0f : entry.getW();
    }

    /** Stores {@code nextWi} as the linear weight of {@code x}, creating its entry if needed. */
    @Override
    protected void setW(@Nonnull Feature x, float nextWi) {
        int key = x.getFeatureIndex();
        Entry entry = getEntry(key);
        if (entry != null) {
            entry.setW(nextWi);
            return;
        }
        float[] V = initV();
        entry = newEntry(nextWi, V);
        _map.put(key, entry.getOffset());
    }

    /**
     * Applies one step to the linear weight of {@code x}.
     *
     * <p>The step is {@code wi - eta * (dloss * Xi + 2 * lambdaW * wi)}.
     *
     * @throws IllegalStateException if the updated weight is not finite
     */
    @Override
    void updateWi(double dloss, @Nonnull Feature x, float eta) {
        double Xi = x.getValue();
        float gradWi = (float) (dloss * Xi);
        Entry theta = getEntry(x);
        float wi = theta.getW();
        float nextWi = wi - eta * (gradWi + 2.0f * _lambdaW * wi);
        if (!NumberUtils.isFinite(nextWi)) {
            throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature() + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss + ", eta=" + eta);
        }
        theta.setW(nextWi);
    }

    /**
     * Applies an FTRL-style update to the linear weight of {@code x}.
     *
     * <p>The entry's z and n accumulators are updated first. If {@code |z| <= lambda1}, the
     * entry is removed from the table, so the weight reads as 0 afterwards. Otherwise the new
     * weight is {@code (sign(z) * lambda1 - z) / ((beta + sqrt(n)) / alpha + lambda2)}.
     *
     * @param dloss loss derivative for the current example
     * @param x     feature being updated
     * @param eta   learning rate; only reported in the error message, not used in the update
     * @return {@code true} if the weight was, or became, non-zero
     * @throws IllegalStateException if the computed weight is not finite
     */
    boolean updateWiFTRL(double dloss, @Nonnull Feature x, float eta) {
        double Xi = x.getValue();
        float gradWi = (float) (dloss * Xi);
        Entry theta = getEntry(x);
        float wi = theta.getW();
        float z = theta.updateZ(gradWi, _alpha);
        double n = theta.updateN(gradWi);
        if (Math.abs(z) <= _lambda1) {
            // NOTE(review): removal only drops the key from _map. The z/n state just accumulated
            // is discarded, and no free of the entry's _buf space is visible in this class. The
            // next getEntry(x) allocates a fresh entry. Confirm the state reset and buffer
            // growth are intended.
            removeEntry(x);
            return wi != 0.0f;
        }
        float nextWi = (float) ((double) (MathUtils.sign(z) * _lambda1 - z) / (((double) _beta + Math.sqrt(n)) / (double) _alpha + (double) _lambda2));
        if (!NumberUtils.isFinite(nextWi)) {
            throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature() + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss + ", eta=" + eta + ", n=" + n + ", z=" + z);
        }
        theta.setW(nextWi);
        return nextWi != 0.0f || wi != 0.0f;
    }

    /** Returns factor {@code f} of {@code x} for field {@code yField}, creating the entry if absent. */
    @Override
    public float getV(@Nonnull Feature x, int yField, int f) {
        return getOrCreateEntry(Feature.toIntFeature(x, yField, _numFields)).getV(f);
    }

    /** Sets factor {@code f} of {@code x} for field {@code yField}, creating the entry if absent. */
    @Override
    protected void setV(@Nonnull Feature x, int yField, int f, float nextVif) {
        getOrCreateEntry(Feature.toIntFeature(x, yField, _numFields)).setV(f, nextVif);
    }

    /** Returns the entry keyed by the feature index of {@code x}, creating it if absent. */
    @Nonnull
    @Override
    protected Entry getEntry(@Nonnull Feature x) {
        return getOrCreateEntry(x.getFeatureIndex());
    }

    /** Returns the entry for ({@code x}, {@code yField}), creating it if absent. */
    @Nonnull
    @Override
    protected Entry getEntry(@Nonnull Feature x, int yField) {
        return getOrCreateEntry(Feature.toIntFeature(x, yField, _numFields));
    }

    /** Unregisters the entry keyed by the feature index of {@code x}. */
    protected void removeEntry(@Nonnull Feature x) {
        _map.remove(x.getFeatureIndex());
    }

    /** Allocates a new, unregistered entry and initializes its W and V. */
    @Nonnull
    protected final Entry newEntry(float W, @Nonnull float[] V) {
        Entry entry = newEntry();
        entry.setW(W);
        entry.setV(V);
        return entry;
    }

    /** Allocates a new, unregistered entry and initializes its V; W is not written. */
    @Nonnull
    protected final Entry newEntry(@Nonnull float[] V) {
        Entry entry = newEntry();
        entry.setV(V);
        return entry;
    }

    /**
     * Looks up {@code key}. If it is absent, creates a new entry with V from {@code initV()} and
     * registers it under {@code key}.
     */
    @Nonnull
    private Entry getOrCreateEntry(int key) {
        Entry entry = getEntry(key);
        if (entry == null) {
            float[] V = initV();
            entry = newEntry(V);
            _map.put(key, entry.getOffset());
        }
        return entry;
    }

    /** Reserves {@link #_entrySize} in the buffer and wraps it in the configured Entry variant. */
    @Nonnull
    private Entry newEntry() {
        // Keep an explicit long local: it forces the getEntry(long) wrapper overload rather than
        // the getEntry(int) table lookup, whatever allocate's declared return type is.
        long ptr = _buf.allocate(_entrySize);
        return getEntry(ptr);
    }

    /** Returns the registered entry for {@code key}, or {@code null} if none exists. */
    @Nullable
    private Entry getEntry(int key) {
        long ptr = _map.get(key);
        return (ptr == -1L) ? null : getEntry(ptr);
    }

    /** Wraps the buffer region at {@code ptr} in the Entry variant matching the optimizer flags. */
    @Nonnull
    private Entry getEntry(long ptr) {
        if (_useFTRL) {
            return new Entry.FTRLEntry(_buf, _factor, ptr);
        }
        if (_useAdaGrad) {
            return new Entry.AdaGradEntry(_buf, _factor, ptr);
        }
        return new Entry(_buf, _factor, ptr);
    }

    /** Selects the per-entry size matching the Entry variant chosen by {@link #getEntry(long)}. */
    private static int entrySize(int factors, boolean ftrl, boolean adagrad) {
        if (ftrl) {
            return Entry.FTRLEntry.sizeOf(factors);
        }
        if (adagrad) {
            return Entry.AdaGradEntry.sizeOf(factors);
        }
        return Entry.sizeOf(factors);
    }
}

