/*
 * Decompiled with CFR 0.152.
 */
package hivemall.mix.client;

import hivemall.mix.MixMessage;
import hivemall.mix.MixedModel;
import hivemall.mix.MixedWeight;
import hivemall.mix.NodeInfo;
import hivemall.mix.client.MixClientHandler;
import hivemall.mix.client.MixClientInitializer;
import hivemall.mix.client.MixRequestRouter;
import hivemall.model.ModelUpdateHandler;
import hivemall.utils.hadoop.HadoopUtils;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import java.io.Closeable;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.HashMap;
import java.util.Map;
import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
import javax.net.ssl.SSLException;

/**
 * Client side of the mix protocol: wraps model updates into {@link MixMessage}s and writes them
 * to the server node chosen by a {@link MixRequestRouter}.
 *
 * <p>Connections are opened lazily. Nothing is connected until the first {@link #onUpdate} call
 * whose {@code deltaUpdates} reaches the configured mix threshold; at that point one channel per
 * node returned by {@link MixRequestRouter#getAllNodes()} is opened.
 *
 * <p>This class contains no synchronization of its own and keeps its channels in a plain
 * {@link HashMap}.
 */
public final class MixClient implements ModelUpdateHandler, Closeable {
    /** Placeholder that {@link #replaceGroupIDIfRequired()} substitutes with the Hadoop job id. */
    public static final String DUMMY_JOB_ID = "__DUMMY_JOB_ID__";

    private final MixMessage.MixEventName event;
    /** Not final: may be rewritten once by {@link #replaceGroupIDIfRequired()}. */
    private String groupID;
    private final boolean ssl;
    private final int mixThreshold;
    private final MixRequestRouter router;
    private final MixClientHandler msgHandler;
    /** Open channel per server node; populated by {@link #configureBootstrap}. */
    private final Map<NodeInfo, Channel> channelMap;
    private boolean initialized = false;
    /** Event loop group shared by all channels; {@code null} before initialization and after close. */
    private EventLoopGroup workers;

    /**
     * Creates a client. No network connection is made here.
     *
     * @param event event name stamped on every message this client sends
     * @param groupID group id set on every message; must not be {@code null}
     * @param connectURIs server list handed to {@link MixRequestRouter}
     * @param ssl whether channels are set up with an SSL context
     * @param mixThreshold minimum {@code deltaUpdates} for {@link #onUpdate} to send; 1..127
     * @param model model handed to the {@link MixClientHandler} that processes server replies
     * @throws IllegalArgumentException if {@code groupID} is {@code null} or
     *         {@code mixThreshold} is outside 1..127
     */
    public MixClient(@Nonnull MixMessage.MixEventName event, @CheckForNull String groupID, @Nonnull String connectURIs, boolean ssl, int mixThreshold, @Nonnull MixedModel model) {
        if (groupID == null) {
            throw new IllegalArgumentException("groupID is null");
        }
        if (mixThreshold < 1 || mixThreshold > 127) {
            throw new IllegalArgumentException("Invalid mixThreshold: " + mixThreshold);
        }
        this.event = event;
        this.groupID = groupID;
        this.router = new MixRequestRouter(connectURIs);
        this.ssl = ssl;
        this.mixThreshold = mixThreshold;
        this.msgHandler = new MixClientHandler(model);
        this.channelMap = new HashMap<NodeInfo, Channel>();
    }

    /**
     * Opens one channel per server node and marks the client as initialized.
     *
     * <p>If resolving the nodes or any connect attempt fails, every channel opened so far is
     * closed and the event loop group is shut down before the exception propagates, so a failed
     * initialization does not leak threads or sockets.
     */
    private void initialize() throws Exception {
        final EventLoopGroup workerGroup = new NioEventLoopGroup();
        boolean connected = false;
        try {
            for (NodeInfo node : this.router.getAllNodes()) {
                this.configureBootstrap(new Bootstrap(), workerGroup, node);
            }
            connected = true;
        } finally {
            if (!connected) {
                // this.workers is not assigned yet, so close() could never release these.
                for (Channel ch : this.channelMap.values()) {
                    ch.close();
                }
                this.channelMap.clear();
                workerGroup.shutdownGracefully();
            }
        }
        this.workers = workerGroup;
        this.initialized = true;
    }

    /**
     * Configures {@code b}, connects it to {@code server} (blocking until the connect attempt
     * completes) and registers the resulting channel in {@link #channelMap}, replacing any
     * previous entry for that node.
     */
    private void configureBootstrap(Bootstrap b, EventLoopGroup workerGroup, NodeInfo server) throws SSLException, InterruptedException {
        // NOTE(review): InsecureTrustManagerFactory accepts any server certificate, so the SSL
        // mode encrypts traffic but does not authenticate the server. Left as-is on purpose;
        // confirm this is acceptable for the deployment.
        SslContext sslCtx = this.ssl ? SslContext.newClientContext(InsecureTrustManagerFactory.INSTANCE) : null;
        b.group(workerGroup);
        b.option(ChannelOption.SO_KEEPALIVE, true);
        b.option(ChannelOption.TCP_NODELAY, true);
        b.channel(NioSocketChannel.class);
        b.handler(new MixClientInitializer(this.msgHandler, sslCtx));
        SocketAddress remoteAddr = server.getSocketAddress();
        ChannelFuture channelFuture = b.connect(remoteAddr).sync();
        this.channelMap.put(server, channelFuture.channel());
    }

    /**
     * Returns an active channel to {@code server}, opening a new one when the registered channel
     * is missing or no longer active.
     *
     * <p>A channel that was connected and has become inactive has been closed, and a closed Netty
     * channel cannot be connected again, so recovery goes through a fresh {@link Bootstrap}
     * instead of calling {@code connect} on the dead channel.
     *
     * @throws IllegalStateException if the client has not been initialized or is already closed
     */
    private Channel activeChannelFor(NodeInfo server) throws SSLException, InterruptedException {
        Channel ch = this.channelMap.get(server);
        if (ch != null && ch.isActive()) {
            return ch;
        }
        if (this.workers == null) {
            throw new IllegalStateException("MixClient is not connected (not initialized or already closed)");
        }
        if (ch != null) {
            // Release whatever is left of the dead channel before replacing it.
            ch.close();
        }
        this.configureBootstrap(new Bootstrap(), this.workers, server);
        return this.channelMap.get(server);
    }

    /**
     * Sends an update for {@code feature} to the responsible server once enough local updates
     * have accumulated.
     *
     * <p>The first call that passes the threshold also resolves the group id and opens the
     * server connections.
     *
     * @return {@code false} if {@code deltaUpdates} is below the mix threshold and nothing was
     *         sent; {@code true} once the message has been handed to the channel (the write
     *         itself is asynchronous and its outcome is not awaited)
     * @throws IllegalStateException if called after {@link #close()}
     */
    @Override
    public boolean onUpdate(Object feature, float weight, float covar, short clock, int deltaUpdates) throws Exception {
        assert (deltaUpdates > 0) : deltaUpdates;
        if (deltaUpdates < this.mixThreshold) {
            return false;
        }
        if (!this.initialized) {
            this.replaceGroupIDIfRequired();
            this.initialize();
        }
        MixMessage msg = new MixMessage(this.event, feature, weight, covar, clock, deltaUpdates);
        msg.setGroupID(this.groupID);
        NodeInfo server = this.router.selectNode(msg);
        this.activeChannelFor(server).writeAndFlush(msg);
        return true;
    }

    /**
     * Sends a message built from the values of {@code mixed} for {@code feature} to the
     * responsible server. The trailing {@code true} passed to the {@link MixMessage} constructor
     * is presumably the cancel flag (its definition is not visible here).
     *
     * @throws IllegalStateException if the client has not been initialized or is already closed
     */
    @Override
    public void sendCancelRequest(@Nonnull Object feature, @Nonnull MixedWeight mixed) throws Exception {
        assert (this.initialized);
        float weight = mixed.getWeight();
        float covar = mixed.getCovar();
        int deltaUpdates = mixed.getDeltaUpdates();
        MixMessage msg = new MixMessage(this.event, feature, weight, covar, deltaUpdates, true);
        assert (this.groupID != null);
        msg.setGroupID(this.groupID);
        NodeInfo server = this.router.selectNode(msg);
        this.activeChannelFor(server).writeAndFlush(msg);
    }

    /**
     * If the group id starts with {@link #DUMMY_JOB_ID}, substitutes the placeholder with the
     * job id reported by {@link HadoopUtils#getJobId()}.
     */
    private void replaceGroupIDIfRequired() {
        if (this.groupID.startsWith(DUMMY_JOB_ID)) {
            String jobId = HadoopUtils.getJobId();
            this.groupID = this.groupID.replace(DUMMY_JOB_ID, jobId);
        }
    }

    /**
     * Closes every open channel and shuts down the event loop group. Does nothing if the client
     * was never initialized or has already been closed. Neither the channel closes nor the group
     * shutdown are awaited.
     */
    @Override
    public void close() throws IOException {
        final EventLoopGroup group = this.workers;
        if (group == null) {
            return;
        }
        this.workers = null;
        for (Channel ch : this.channelMap.values()) {
            ch.close();
        }
        this.channelMap.clear();
        group.shutdownGracefully();
    }
}

