/*
 * Decompiled with CFR 0.152.
 */
package hivemall.fm;

import hivemall.fm.Entry;
import hivemall.fm.Feature;
import hivemall.utils.buffer.HeapBuffer;
import hivemall.utils.codec.VariableByteCodec;
import hivemall.utils.codec.ZigZagLEB128Codec;
import hivemall.utils.collections.Int2LongOpenHashTable;
import hivemall.utils.io.CompressionStreamFactory;
import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.ArrayUtils;
import hivemall.utils.lang.HalfFloat;
import hivemall.utils.lang.ObjectUtils;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Prediction-time representation of a Field-aware Factorization Machine (FFM) model.
 * <p>
 * Integer feature keys are mapped by an {@link Int2LongOpenHashTable} to offsets into a
 * {@link HeapBuffer}; each offset addresses an {@link Entry} holding a scalar {@code W} and a
 * factor vector {@code V} of length {@link #getNumFactors()}.
 * <p>
 * The class implements {@link Externalizable} with a compact custom encoding (see
 * {@link #writeExternal(ObjectOutput)}). Note that {@code writeExternal} releases the table and
 * buffer after writing, so an instance cannot be used for lookups or serialized again afterwards.
 */
public final class FFMPredictionModel implements Externalizable {
    private static final Log LOG = LogFactory.getLog(FFMPredictionModel.class);

    // Type tag written as the first byte of every serialized entry.
    /** W and V, each value stored as a 16-bit half float. */
    private static final byte HALF_FLOAT_ENTRY = 1;
    /** W only (V was approximately zero), stored as a 16-bit half float. */
    private static final byte W_ONLY_HALF_FLOAT_ENTRY = 2;
    /** W and V, each value stored as a 32-bit float. */
    private static final byte FLOAT_ENTRY = 3;
    /** W only (V was approximately zero), stored as a 32-bit float. */
    private static final byte W_ONLY_FLOAT_ENTRY = 4;

    /**
     * Slot state treated as "occupied" in the array returned by
     * {@link Int2LongOpenHashTable#getStates()}. Presumably mirrors the table's own FULL code --
     * confirm against Int2LongOpenHashTable.
     */
    private static final byte STATE_FULL = 1;
    /** Slot state assigned on read to every slot that was not {@link #STATE_FULL} on write. */
    private static final byte STATE_FREE = 0;

    /** Chunk size passed to {@link HeapBuffer} when rebuilding the model in readExternal. */
    private static final int CHUNK_SIZE = 0x400000;
    /**
     * Divisor used to estimate the initial number of chunks from the total entry bytes
     * (0x1000000 == 4 * CHUNK_SIZE; presumably the chunk size is counted in 4-byte units --
     * confirm against HeapBuffer).
     */
    private static final long CHUNK_BYTES = 4L * CHUNK_SIZE;

    private Int2LongOpenHashTable _map;
    private HeapBuffer _buf;
    private double _w0;
    private int _factors;
    private int _numFeatures;
    private int _numFields;

    /** No-arg constructor required by {@link Externalizable}; populate via readExternal. */
    public FFMPredictionModel() {
    }

    /**
     * Creates a model over an already-populated key table and entry buffer.
     *
     * @param map feature key to buffer-offset table
     * @param buf buffer holding the entries addressed by {@code map}
     * @param w0 global bias term
     * @param factor number of latent factors per entry
     * @param numFeatures number of features as configured for training
     * @param numFields number of fields
     */
    public FFMPredictionModel(@Nonnull Int2LongOpenHashTable map, @Nonnull HeapBuffer buf,
            double w0, int factor, int numFeatures, int numFields) {
        this._map = map;
        this._buf = buf;
        this._w0 = w0;
        this._factors = factor;
        this._numFeatures = numFeatures;
        this._numFields = numFields;
    }

    public int getNumFactors() {
        return _factors;
    }

    public double getW0() {
        return _w0;
    }

    public int getNumFeatures() {
        return _numFeatures;
    }

    public int getNumFields() {
        return _numFields;
    }

    /** Returns the number of keys actually stored in the table. */
    public int getActualNumFeatures() {
        return _map.size();
    }

    /**
     * Returns a rough estimate of the memory consumed by this model, in bytes.
     * <p>
     * The estimate charges {@code 9 + 4 * factors} bytes per stored key, 1 byte per unused slot
     * of the table capacity, and a fixed 28 bytes for the scalar fields.
     */
    public long approxBytesConsumed() {
        final int used = _map.size();
        long bytes = (long) used * (9L + 4L * _factors);
        final int unused = _map.capacity() - used;
        if (unused > 0) {
            bytes += unused;
        }
        bytes += 28L;
        return bytes;
    }

    /**
     * Looks up the entry for {@code key}.
     *
     * @return a view over the stored entry, or {@code null} if the key is absent (the table
     *         returns -1 for missing keys)
     */
    @Nullable
    private Entry getEntry(int key) {
        final long ptr = _map.get(key);
        if (ptr == -1L) {
            return null;
        }
        return new Entry(_buf, _factors, ptr);
    }

    /**
     * Returns the linear weight {@code W} for {@code x}'s feature index, or 0 if the feature is
     * not in the model.
     */
    public float getW(@Nonnull Feature x) {
        final Entry entry = getEntry(x.getFeatureIndex());
        return (entry == null) ? 0.0f : entry.getW();
    }

    /**
     * Copies the factor vector {@code V} for feature {@code x} with respect to field
     * {@code yField} into {@code dst}.
     *
     * @param x the feature
     * @param yField the field the factor vector is taken against
     * @param dst destination array, filled via {@link Entry#getV(float[])} when an entry exists
     * @return {@code false} if no entry exists for the (feature, field) key, otherwise the
     *         negation of {@code ArrayUtils.equals(dst, 0.0f)} (presumably: {@code false} when
     *         every component is zero -- confirm)
     */
    public boolean getV(@Nonnull Feature x, int yField, @Nonnull float[] dst) {
        final int key = Feature.toIntFeature(x, yField, _numFields);
        final Entry entry = getEntry(key);
        if (entry == null) {
            return false;
        }
        entry.getV(dst);
        return !ArrayUtils.equals(dst, 0.0f);
    }

    /**
     * Writes the model in a compact format.
     * <p>
     * Layout:
     * <ul>
     * <li>{@code w0} (double), {@code factors}, {@code numFeatures}, {@code numFields},
     * {@code used}, {@code size} (ints)</li>
     * <li>slot states (see {@link #writeStates(byte[], DataOutput)})</li>
     * <li>for every occupied slot, in slot order: the zig-zag LEB128 key followed by the entry
     * (see {@code writeEntry})</li>
     * </ul>
     * After writing, the table and buffer references are dropped (to let them be garbage
     * collected), so this instance can no longer serve lookups or be written again.
     *
     * @throws IllegalStateException if the model holds no table/buffer (already released or
     *         never populated)
     */
    @Override
    public void writeExternal(@Nonnull ObjectOutput out) throws IOException {
        ensureNotReleased();

        final int factors = _factors;
        out.writeDouble(_w0);
        out.writeInt(factors);
        out.writeInt(_numFeatures);
        out.writeInt(_numFields);

        out.writeInt(_map.size());
        final int[] keys = _map.getKeys();
        final int size = keys.length;
        out.writeInt(size);

        final byte[] states = _map.getStates();
        writeStates(states, out);

        final long[] values = _map.getValues();
        final Entry e = new Entry(_buf, factors);
        final float[] Vf = new float[factors]; // scratch buffer reused for every entry
        for (int i = 0; i < size; i++) {
            if (states[i] != STATE_FULL) {
                continue;
            }
            ZigZagLEB128Codec.writeSignedInt(keys[i], out);
            e.setOffset(values[i]);
            writeEntry(e, factors, Vf, out);
        }

        // help GC: the (possibly large) table and buffer are not needed after serialization
        this._map = null;
        this._buf = null;
    }

    /**
     * Writes one entry as a type tag followed by its payload.
     * <p>
     * If {@code V} is approximately zero (per {@code ArrayUtils.almostEquals}), only {@code W} is
     * written, so near-zero factor vectors are intentionally dropped. Values are written as half
     * floats when {@code HalfFloat.isRepresentable} accepts all of them, otherwise as floats.
     *
     * @param Vf scratch array of length {@code factors}; overwritten with the entry's V
     */
    private static void writeEntry(@Nonnull Entry e, int factors, @Nonnull float[] Vf,
            @Nonnull DataOutput out) throws IOException {
        final float W = e.getW();
        e.getV(Vf);

        if (ArrayUtils.almostEquals(Vf, 0.0f)) {
            if (HalfFloat.isRepresentable(W)) {
                out.writeByte(W_ONLY_HALF_FLOAT_ENTRY);
                out.writeShort(HalfFloat.floatToHalfFloat(W));
            } else {
                out.writeByte(W_ONLY_FLOAT_ENTRY);
                out.writeFloat(W);
            }
        } else if (isRepresentableAsHalfFloat(W, Vf)) {
            out.writeByte(HALF_FLOAT_ENTRY);
            out.writeShort(HalfFloat.floatToHalfFloat(W));
            for (int i = 0; i < factors; i++) {
                out.writeShort(HalfFloat.floatToHalfFloat(Vf[i]));
            }
        } else {
            out.writeByte(FLOAT_ENTRY);
            out.writeFloat(W);
            IOUtils.writeFloats(Vf, factors, out);
        }
    }

    /** Returns true iff {@code W} and every element of {@code Vf} fit in a half float. */
    private static boolean isRepresentableAsHalfFloat(float W, @Nonnull float[] Vf) {
        if (!HalfFloat.isRepresentable(W)) {
            return false;
        }
        for (float V : Vf) {
            if (!HalfFloat.isRepresentable(V)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Writes the indexes of all non-{@link #STATE_FULL} slots: first their count (int), then the
     * gaps between consecutive indexes as variable-byte unsigned ints.
     * <p>
     * NOTE(review): every non-FULL state is encoded identically, so if the table can contain a
     * distinct "removed" marker it is restored as FREE by readStates, which could break
     * open-addressing probe chains. Presumably prediction models never have removals -- confirm.
     * Changing this would change the wire format.
     */
    static void writeStates(@Nonnull byte[] status, @Nonnull DataOutput out)
            throws IOException {
        final int size = status.length;
        int cardinality = 0;
        for (int i = 0; i < size; i++) {
            if (status[i] != STATE_FULL) {
                cardinality++;
            }
        }
        out.writeInt(cardinality);
        if (cardinality == 0) {
            return;
        }

        int prev = 0;
        for (int i = 0; i < size; i++) {
            if (status[i] == STATE_FULL) {
                continue;
            }
            final int diff = i - prev;
            assert (diff >= 0);
            VariableByteCodec.encodeUnsignedInt(diff, out);
            prev = i;
        }
    }

    /**
     * Reads a model written by {@link #writeExternal(ObjectOutput)} and rebuilds the key table and
     * entry buffer.
     *
     * @throws IOException on I/O failure or when the stream contains invalid header values,
     *         state indexes, or entry type tags
     */
    @Override
    public void readExternal(@Nonnull ObjectInput in) throws IOException, ClassNotFoundException {
        this._w0 = in.readDouble();
        final int factors = in.readInt();
        this._factors = factors;
        this._numFeatures = in.readInt();
        this._numFields = in.readInt();

        final int used = in.readInt();
        final int size = in.readInt();
        if (factors < 0 || used < 0 || size < 0) {
            throw new IOException("Corrupted FFMPredictionModel header: factors=" + factors
                    + ", used=" + used + ", size=" + size);
        }

        final int[] keys = new int[size];
        final long[] values = new long[size];
        final byte[] states = new byte[size];
        readStates(in, states);

        final int entrySize = Entry.sizeOf(factors);
        // Multiply in long: entrySize * used overflows int for large models, which would yield a
        // negative chunk count.
        final int numChunks = (int) ((long) entrySize * used / CHUNK_BYTES) + 1;
        final HeapBuffer buf = new HeapBuffer(CHUNK_SIZE, numChunks);

        final Entry e = new Entry(buf, factors);
        final float[] Vf = new float[factors]; // scratch buffer reused for every entry
        for (int i = 0; i < size; i++) {
            if (states[i] != STATE_FULL) {
                continue;
            }
            keys[i] = ZigZagLEB128Codec.readSignedInt(in);
            final long ptr = buf.allocate(entrySize);
            e.setOffset(ptr);
            readEntry(in, factors, Vf, e);
            values[i] = ptr;
        }

        this._map = new Int2LongOpenHashTable(keys, values, states, used);
        this._buf = buf;
    }

    /**
     * Reads one tagged entry (as written by {@code writeEntry}) into {@code dst}.
     * <p>
     * For W-only tags, V is not written into {@code dst}, so the freshly allocated slot keeps
     * whatever the buffer initially holds (presumably zeros -- confirm against HeapBuffer).
     *
     * @param Vf scratch array of length {@code factors}
     * @throws IOException if the type tag is unknown
     */
    private static void readEntry(@Nonnull DataInput in, int factors, @Nonnull float[] Vf,
            @Nonnull Entry dst) throws IOException {
        final byte type = in.readByte();
        switch (type) {
            case HALF_FLOAT_ENTRY: {
                dst.setW(HalfFloat.halfFloatToFloat(in.readShort()));
                for (int i = 0; i < factors; i++) {
                    Vf[i] = HalfFloat.halfFloatToFloat(in.readShort());
                }
                dst.setV(Vf);
                break;
            }
            case W_ONLY_HALF_FLOAT_ENTRY: {
                dst.setW(HalfFloat.halfFloatToFloat(in.readShort()));
                break;
            }
            case FLOAT_ENTRY: {
                dst.setW(in.readFloat());
                IOUtils.readFloats(in, Vf);
                dst.setV(Vf);
                break;
            }
            case W_ONLY_FLOAT_ENTRY: {
                dst.setW(in.readFloat());
                break;
            }
            default:
                throw new IOException("Unexpected Entry type: " + type);
        }
    }

    /**
     * Inverse of {@link #writeStates(byte[], DataOutput)}: marks every slot
     * {@link #STATE_FULL}, then sets each decoded index to {@link #STATE_FREE}.
     *
     * @throws IOException if the encoded count or any decoded index is out of range for
     *         {@code status}
     */
    static void readStates(@Nonnull DataInput in, @Nonnull byte[] status) throws IOException {
        final int cardinality = in.readInt();
        if (cardinality < 0 || cardinality > status.length) {
            throw new IOException("Invalid number of non-full states: " + cardinality
                    + " (table size " + status.length + ")");
        }
        Arrays.fill(status, STATE_FULL);

        int prev = 0;
        for (int j = 0; j < cardinality; j++) {
            final int i = VariableByteCodec.decodeUnsignedInt(in) + prev;
            // i < prev catches a negative decoded gap or int overflow
            if (i < prev || i >= status.length) {
                throw new IOException(
                    "Invalid state index " + i + " (table size " + status.length + ")");
            }
            status[i] = STATE_FREE;
            prev = i;
        }
    }

    /**
     * Serializes this model to LZMA2-compressed bytes. Like writeExternal, this releases the
     * model's table and buffer.
     *
     * @throws IllegalStateException if the model holds no table/buffer (already released or
     *         never populated)
     */
    public byte[] serialize() throws IOException {
        ensureNotReleased();
        LOG.info("FFMPredictionModel#serialize(): " + _buf.toString());
        return ObjectUtils.toCompressedBytes(this,
            CompressionStreamFactory.CompressionAlgorithm.lzma2, true);
    }

    /**
     * Reconstructs a model from the first {@code len} bytes of data produced by
     * {@link #serialize()}.
     */
    public static FFMPredictionModel deserialize(@Nonnull byte[] serializedObj, int len)
            throws ClassNotFoundException, IOException {
        final FFMPredictionModel model = new FFMPredictionModel();
        ObjectUtils.readCompressedObject(serializedObj, len, model,
            CompressionStreamFactory.CompressionAlgorithm.lzma2, true);
        LOG.info("FFMPredictionModel#deserialize(): " + model._buf.toString());
        return model;
    }

    /** Fails fast with a descriptive message instead of a bare NullPointerException. */
    private void ensureNotReleased() {
        if (_map == null || _buf == null) {
            throw new IllegalStateException("FFMPredictionModel holds no model data "
                    + "(not yet deserialized, or already released by writeExternal)");
        }
    }
}

