/*
 * Decompiled with CFR 0.152.
 */
package hivemall.common;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Tracks accumulated errors and losses across training iterations and decides
 * whether training has converged.
 *
 * <p>Losses are accumulated into a "current" bucket through {@link #incrLoss(double)}
 * and {@link #multiplyLoss(double)}. Each call to {@link #isConverged(int, long)} that
 * returns {@code false} moves the current loss into the "previous" bucket and resets
 * the current loss to zero.
 */
public final class ConversionState {
    private static final Log logger = LogFactory.getLog(ConversionState.class);

    /** When {@code false}, {@link #isConverged(int, long)} never reports convergence. */
    protected final boolean conversionCheck;
    /** Threshold that the relative loss change must fall below to count as "small". */
    protected final double convergenceRate;
    /** Set once the change rate was below the threshold; a second such call converges. */
    protected boolean readyToFinishIterations;
    /** Sum of all values passed to {@link #incrError(double)}. */
    protected double totalErrors;
    /** Loss accumulated since the last roll-over. */
    protected double currLosses;
    /** Loss of the preceding roll-over; {@code +Infinity} until the first roll-over. */
    protected double prevLosses;
    /** Iteration number last passed to {@link #logState(int, float)}. */
    protected int curIter;
    /** Eta last passed to {@link #logState(int, float)}; {@code NaN} until then. */
    protected float curEta;

    /** Creates a state with the convergence check enabled and a rate of {@code 0.005}. */
    public ConversionState() {
        this(true, 0.005);
    }

    /**
     * @param conversionCheck whether {@link #isConverged(int, long)} may report convergence
     * @param convergenceRate threshold compared against the relative loss change
     */
    public ConversionState(boolean conversionCheck, double convergenceRate) {
        this.conversionCheck = conversionCheck;
        this.convergenceRate = convergenceRate;
        this.readyToFinishIterations = false;
        this.totalErrors = 0.0;
        this.currLosses = 0.0;
        this.prevLosses = Double.POSITIVE_INFINITY;
        this.curIter = 0;
        this.curEta = Float.NaN;
    }

    /** @return the sum of all errors added so far */
    public double getTotalErrors() {
        return totalErrors;
    }

    /** @return the loss accumulated since the last roll-over */
    public double getCumulativeLoss() {
        return currLosses;
    }

    /** @return the loss recorded at the last roll-over, or {@code +Infinity} if none yet */
    public double getPreviousLoss() {
        return prevLosses;
    }

    /** Adds {@code error} to the running error total. */
    public void incrError(double error) {
        totalErrors += error;
    }

    /** Adds {@code loss} to the current loss. */
    public void incrLoss(double loss) {
        currLosses += loss;
    }

    /** Scales the current loss by {@code multi}. */
    public void multiplyLoss(double multi) {
        currLosses *= multi;
    }

    /** @return {@code true} if the current loss is strictly greater than the previous loss */
    public boolean isLossIncreased() {
        return currLosses > prevLosses;
    }

    /**
     * Checks for convergence and, unless convergence is reported, rolls the current
     * loss over into the previous loss and resets the current loss to zero.
     *
     * <p>Convergence is reported only when the relative change
     * {@code (prevLosses - currLosses) / prevLosses} is below {@code convergenceRate}
     * on two consecutive calls: the first such call only arms
     * {@code readyToFinishIterations}, and any call in which the loss increased or the
     * change rate was not below the threshold disarms it again. When {@code true} is
     * returned the losses are left untouched.
     *
     * @param iter iteration number, used in log messages only
     * @param obserbedTrainingExamples number of examples seen, used in log messages only
     * @return {@code true} if training is considered converged
     */
    public boolean isConverged(int iter, long obserbedTrainingExamples) {
        if (!conversionCheck) {
            rollOverLosses();
            return false;
        }
        if (currLosses > prevLosses) {
            if (logger.isInfoEnabled()) {
                logger.info("Iteration #" + iter + " currLoss `" + currLosses
                        + "` > prevLosses `" + prevLosses + '`');
            }
            rollOverLosses();
            readyToFinishIterations = false;
            return false;
        }
        final double changeRate = computeChangeRate();
        if (changeRate < convergenceRate) {
            if (readyToFinishIterations) {
                if (logger.isInfoEnabled()) {
                    logger.info("Training converged at " + iter + "-th iteration. [curLosses="
                            + currLosses + ", prevLosses=" + prevLosses + ", changeRate="
                            + changeRate + ']');
                }
                return true;
            }
            readyToFinishIterations = true;
        } else {
            if (logger.isDebugEnabled()) {
                logger.debug("Iteration #" + iter + " [curLosses=" + currLosses
                        + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate
                        + ", #trainingExamples=" + obserbedTrainingExamples + ']');
            }
            readyToFinishIterations = false;
        }
        rollOverLosses();
        return false;
    }

    /**
     * Logs the current and previous loss together with {@code eta} at INFO level and
     * records {@code iter} and {@code eta} as the current iteration and eta.
     */
    public void logState(int iter, float eta) {
        if (logger.isInfoEnabled()) {
            logger.info("Iteration #" + iter + " [curLoss=" + currLosses + ", prevLoss="
                    + prevLosses + ", eta=" + eta + ']');
        }
        this.curIter = iter;
        this.curEta = eta;
    }

    /** @return the iteration number last passed to {@link #logState(int, float)} */
    public int getCurrentIteration() {
        return curIter;
    }

    /** @return the eta last passed to {@link #logState(int, float)} */
    public float getCurrentEta() {
        return curEta;
    }

    /** Moves the current loss into the previous loss and clears the current loss. */
    private void rollOverLosses() {
        prevLosses = currLosses;
        currLosses = 0.0;
    }

    /**
     * Computes the relative loss change {@code (prevLosses - currLosses) / prevLosses}.
     *
     * <p>An exactly unchanged loss yields {@code 0.0} even when both losses are zero;
     * plain division would give {@code 0.0 / 0.0 = NaN}, which never compares below the
     * threshold and would keep a zero-loss run from ever converging. While
     * {@code prevLosses} is still infinite the result is {@code NaN}, so the first
     * iteration is never treated as a small change.
     */
    private double computeChangeRate() {
        final double diff = prevLosses - currLosses;
        if (diff == 0.0) {
            return 0.0;
        }
        return diff / prevLosses;
    }
}

