/*
 * Decompiled with CFR 0.152.
 */
package hivemall.common;

import hivemall.utils.math.MathUtils;

public final class LossFunctions {
    /**
     * Creates a loss function from its name. The name is matched case-insensitively against
     * the {@link LossType} constant names, so e.g. {@code "hingeloss"} selects {@link HingeLoss}.
     * Each call returns a fresh instance built with the implementation's default constructor.
     *
     * @param type the loss-function name; {@code null} is treated as unsupported
     * @return a new {@link LossFunction}
     * @throws IllegalArgumentException if {@code type} names no supported loss function
     */
    public static LossFunction getLossFunction(String type) {
        for (LossType candidate : LossType.values()) {
            if (candidate.name().equalsIgnoreCase(type)) {
                return getLossFunction(candidate);
            }
        }
        throw new IllegalArgumentException("Unsupported type: " + type);
    }

    /**
     * Creates a new loss-function instance for the given {@link LossType}, using each
     * implementation's default constructor.
     *
     * @param type the loss type; must not be {@code null}
     * @return a new {@link LossFunction}
     * @throws NullPointerException if {@code type} is {@code null} (raised by the switch)
     * @throws IllegalArgumentException if {@code type} has no mapping here
     */
    public static LossFunction getLossFunction(LossType type) {
        final LossFunction created;
        switch (type) {
            case SquaredLoss:
                created = new SquaredLoss();
                break;
            case LogLoss:
                created = new LogLoss();
                break;
            case HingeLoss:
                created = new HingeLoss();
                break;
            case SquaredHingeLoss:
                created = new SquaredHingeLoss();
                break;
            case QuantileLoss:
                created = new QuantileLoss();
                break;
            case EpsilonInsensitiveLoss:
                created = new EpsilonInsensitiveLoss();
                break;
            default:
                throw new IllegalArgumentException("Unsupported type: " + type);
        }
        return created;
    }

    /**
     * Returns {@code target - sigmoid(predicted)}, with the sigmoid taken from
     * {@code MathUtils.sigmoid}. Despite its name, this returns a residual, not a loss value.
     * NOTE(review): it is presumably the logistic-regression gradient term for 0/1 targets;
     * confirm against callers.
     *
     * <p>When {@code predicted <= -100} or {@code predicted} is NaN, the sigmoid call is skipped
     * and {@code target} is returned unchanged.
     *
     * @param target the target value
     * @param predicted the raw (pre-sigmoid) prediction
     * @return {@code target - sigmoid(predicted)}, or {@code target} below the cutoff
     */
    public static float logisticLoss(float target, float predicted) {
        // Kept as a "greater than" test so a NaN prediction also takes the early return.
        final boolean aboveCutoff = (double) predicted > -100.0;
        if (!aboveCutoff) {
            return target;
        }
        return target - (float) MathUtils.sigmoid(predicted);
    }

    /**
     * Computes the logistic loss {@code log(1 + exp(-y * p))} for a binary target.
     *
     * <p>When the margin {@code z = y * p} lies outside [-18, 18], the asymptotic forms
     * {@code -z} (very negative margin) and {@code exp(-z)} (very positive margin) are returned
     * instead of evaluating the full expression.
     *
     * @param p the prediction
     * @param y the target, which must be exactly {@code +1} or {@code -1}
     * @return the loss value
     * @throws IllegalArgumentException if {@code y} is not {@code +1} or {@code -1}
     */
    public static float logLoss(float p, float y) {
        BinaryLoss.checkTarget(y);
        final float margin = y * p;
        if (margin < -18.0f) {
            // log(1 + e^{-z}) ~= -z
            return -margin;
        }
        if (margin > 18.0f) {
            // log(1 + e^{-z}) ~= e^{-z}
            return (float) Math.exp(-margin);
        }
        return (float) Math.log(1.0 + Math.exp(-margin));
    }

    /**
     * Double-precision counterpart of {@link #logLoss(float, float)}.
     *
     * @throws IllegalArgumentException if {@code y} is not {@code +1} or {@code -1}
     */
    public static double logLoss(double p, double y) {
        BinaryLoss.checkTarget(y);
        final double margin = y * p;
        if (margin < -18.0) {
            return -margin;
        }
        if (margin > 18.0) {
            return Math.exp(-margin);
        }
        return Math.log(1.0 + Math.exp(-margin));
    }

    /**
     * Computes the halved squared error {@code (p - y)^2 / 2}.
     *
     * @param p the prediction
     * @param y the target
     * @return half the squared difference
     */
    public static float squaredLoss(float p, float y) {
        final float diff = p - y;
        // Square first, then halve, matching the original evaluation order (overflow-sensitive).
        return diff * diff / 2.0f;
    }

    /** Double-precision counterpart of {@link #squaredLoss(float, float)}. */
    public static double squaredLoss(double p, double y) {
        final double diff = p - y;
        return diff * diff / 2.0;
    }

    /**
     * Computes the raw hinge margin {@code threshold - y * p}.
     *
     * <p>The result is NOT clamped at zero. It is negative when {@code y * p} exceeds the
     * threshold; {@link HingeLoss} applies the clamp itself.
     *
     * @param p the prediction
     * @param y the target, which must be exactly {@code +1} or {@code -1}
     * @param threshold the margin threshold
     * @return {@code threshold - y * p}
     * @throws IllegalArgumentException if {@code y} is not {@code +1} or {@code -1}
     */
    public static float hingeLoss(float p, float y, float threshold) {
        BinaryLoss.checkTarget(y);
        return threshold - (y * p);
    }

    /** Double-precision counterpart of {@link #hingeLoss(float, float, float)}. */
    public static double hingeLoss(double p, double y, double threshold) {
        BinaryLoss.checkTarget(y);
        return threshold - (y * p);
    }

    /** Raw hinge margin with the default threshold of 1: {@code 1 - y * p}. */
    public static float hingeLoss(float p, float y) {
        return hingeLoss(p, y, 1.0f);
    }

    /** Double-precision raw hinge margin with threshold 1: {@code 1 - y * p}. */
    public static double hingeLoss(double p, double y) {
        return hingeLoss(p, y, 1.0);
    }

    /**
     * Computes the squared hinge loss {@code max(0, 1 - y * p)^2}.
     *
     * @param p the prediction
     * @param y the target, which must be exactly {@code +1} or {@code -1}
     * @return the squared positive part of {@code 1 - y * p}; 0 if it is not positive
     *         (including NaN)
     * @throws IllegalArgumentException if {@code y} is not {@code +1} or {@code -1}
     */
    public static float squaredHingeLoss(float p, float y) {
        BinaryLoss.checkTarget(y);
        final float slack = 1.0f - y * p;
        if (slack > 0.0f) {
            return slack * slack;
        }
        return 0.0f;
    }

    /** Double-precision counterpart of {@link #squaredHingeLoss(float, float)}. */
    public static double squaredHingeLoss(double p, double y) {
        BinaryLoss.checkTarget(y);
        final double slack = 1.0 - y * p;
        if (slack > 0.0) {
            return slack * slack;
        }
        return 0.0;
    }

    /**
     * Computes the raw epsilon-insensitive error {@code |target - predicted| - epsilon}.
     * The result is NOT clamped at zero; it is negative when the absolute error is below
     * {@code epsilon}.
     *
     * @param predicted the prediction
     * @param target the target
     * @param epsilon the width of the insensitive band
     * @return the absolute error minus {@code epsilon}
     */
    public static float epsilonInsensitiveLoss(float predicted, float target, float epsilon) {
        final float absError = Math.abs(target - predicted);
        return absError - epsilon;
    }

    /**
     * Epsilon-insensitive regression loss {@code max(0, |y - p| - epsilon)}. Errors within
     * {@code epsilon} of the target cost nothing. Epsilon defaults to 0.1.
     */
    public static final class EpsilonInsensitiveLoss
    extends RegressionLoss {
        private float epsilon;

        /** Creates the loss with {@code epsilon = 0.1}. */
        public EpsilonInsensitiveLoss() {
            this(0.1f);
        }

        /**
         * @param epsilon the width of the zero-loss band; must be non-negative
         * @throws IllegalArgumentException if {@code epsilon} is negative or NaN
         */
        public EpsilonInsensitiveLoss(float epsilon) {
            this.setEpsilon(epsilon);
        }

        /**
         * Sets the width of the zero-loss band.
         *
         * <p>A negative epsilon is rejected because {@link #dloss(float, float)} would then
         * disagree in sign with {@link #loss(float, float)} near the target. For example, it
         * would return -1 when {@code p} is slightly above {@code y}, although the loss is
         * increasing in {@code p} there.
         *
         * @param epsilon the band width; must be non-negative
         * @throws IllegalArgumentException if {@code epsilon} is negative or NaN
         */
        public void setEpsilon(float epsilon) {
            // Negated ">=" so that NaN, which fails every comparison, is rejected too.
            if (!(epsilon >= 0.0f)) {
                throw new IllegalArgumentException("epsilon must be non-negative: " + epsilon);
            }
            this.epsilon = epsilon;
        }

        /** Returns {@code max(0, |y - p| - epsilon)}. */
        @Override
        public float loss(float p, float y) {
            float loss = Math.abs(y - p) - this.epsilon;
            return loss > 0.0f ? loss : 0.0f;
        }

        /** Double-precision counterpart of {@link #loss(float, float)}. */
        @Override
        public double loss(double p, double y) {
            double loss = Math.abs(y - p) - (double)this.epsilon;
            return loss > 0.0 ? loss : 0.0;
        }

        /**
         * Derivative of the loss with respect to {@code p}. Returns -1 when the prediction is
         * more than epsilon below the target, +1 when it is more than epsilon above it, and 0
         * inside the band.
         */
        @Override
        public float dloss(float p, float y) {
            if (y - p > this.epsilon) {
                return -1.0f;
            }
            if (p - y > this.epsilon) {
                return 1.0f;
            }
            return 0.0f;
        }
    }

    /**
     * Quantile (pinball) regression loss for quantile {@code tau} in (0, 1), with
     * {@code e = y - p}. The loss is {@code tau * e} when {@code e > 0} and
     * {@code -(1 - tau) * e} otherwise. Tau defaults to 0.5.
     */
    public static final class QuantileLoss
    extends RegressionLoss {
        private float tau;

        /** Creates the loss with {@code tau = 0.5}. */
        public QuantileLoss() {
            this(0.5f);
        }

        /**
         * @param tau the target quantile, strictly between 0 and 1
         * @throws IllegalArgumentException if {@code tau} is outside (0, 1) or NaN
         */
        public QuantileLoss(float tau) {
            this.setTau(tau);
        }

        /**
         * Sets the target quantile.
         *
         * @param tau the quantile, strictly between 0 and 1
         * @throws IllegalArgumentException if {@code tau} is outside (0, 1) or NaN
         */
        public void setTau(float tau) {
            // Written as a negated "inside" test: NaN fails both comparisons and is rejected.
            // The previous form (tau <= 0 || tau >= 1) let NaN through.
            if (!(tau > 0.0f && tau < 1.0f)) {
                throw new IllegalArgumentException("tau must be in range (0, 1): " + tau);
            }
            this.tau = tau;
        }

        /** Returns the pinball loss of prediction {@code p} against target {@code y}. */
        @Override
        public float loss(float p, float y) {
            float e = y - p;
            if (e > 0.0f) {
                return this.tau * e;
            }
            return -(1.0f - this.tau) * e;
        }

        /** Double-precision counterpart of {@link #loss(float, float)}. */
        @Override
        public double loss(double p, double y) {
            double e = y - p;
            if (e > 0.0) {
                return (double)this.tau * e;
            }
            return -(1.0 - (double)this.tau) * e;
        }

        /**
         * Derivative of the loss with respect to {@code p}. Returns {@code -tau} when the
         * prediction is below the target, {@code 1 - tau} when it is above, and 0 when they
         * are equal.
         */
        @Override
        public float dloss(float p, float y) {
            float e = y - p;
            if (e == 0.0f) {
                return 0.0f;
            }
            return e > 0.0f ? -this.tau : 1.0f - this.tau;
        }
    }

    /**
     * Squared hinge loss {@code max(0, 1 - y * p)^2} for targets of exactly +1 or -1.
     */
    public static final class SquaredHingeLoss
    extends BinaryLoss {
        @Override
        public float loss(float p, float y) {
            return LossFunctions.squaredHingeLoss(p, y);
        }

        @Override
        public double loss(double p, double y) {
            return LossFunctions.squaredHingeLoss(p, y);
        }

        /**
         * Derivative with respect to {@code p}. It is {@code -2 * (1 - y * p) * y} while
         * {@code 1 - y * p} is positive, and 0 otherwise.
         *
         * @throws IllegalArgumentException if {@code y} is not {@code +1} or {@code -1}
         */
        @Override
        public float dloss(float p, float y) {
            checkTarget(y);
            final float slack = 1.0f - y * p;
            if (slack > 0.0f) {
                return -2.0f * slack * y;
            }
            return 0.0f;
        }
    }

    /**
     * Hinge loss {@code max(0, threshold - y * p)} for targets of exactly +1 or -1.
     * The threshold defaults to 1.
     */
    public static final class HingeLoss
    extends BinaryLoss {
        private float threshold;

        /** Creates a hinge loss with threshold 1. */
        public HingeLoss() {
            this(1.0f);
        }

        /** @param threshold the margin threshold */
        public HingeLoss(float threshold) {
            this.threshold = threshold;
        }

        /** @param threshold the new margin threshold */
        public void setThreshold(float threshold) {
            this.threshold = threshold;
        }

        @Override
        public float loss(float p, float y) {
            final float raw = LossFunctions.hingeLoss(p, y, threshold);
            // Not Math.max: a NaN raw margin must still map to 0.
            if (raw > 0.0f) {
                return raw;
            }
            return 0.0f;
        }

        @Override
        public double loss(double p, double y) {
            final double raw = LossFunctions.hingeLoss(p, y, (double) threshold);
            if (raw > 0.0) {
                return raw;
            }
            return 0.0;
        }

        /**
         * Derivative with respect to {@code p}: {@code -y} while the raw margin is positive,
         * otherwise 0.
         */
        @Override
        public float dloss(float p, float y) {
            final boolean active = LossFunctions.hingeLoss(p, y, threshold) > 0.0f;
            return active ? -y : 0.0f;
        }
    }

    /**
     * Logistic loss {@code log(1 + exp(-y * p))} for targets of exactly +1 or -1.
     */
    public static final class LogLoss
    extends BinaryLoss {
        /**
         * Delegates to {@link LossFunctions#logLoss(float, float)}, which performs the same
         * target check and computation. Delegating keeps the two copies from drifting apart.
         */
        @Override
        public float loss(float p, float y) {
            return LossFunctions.logLoss(p, y);
        }

        /** Delegates to {@link LossFunctions#logLoss(double, double)}. */
        @Override
        public double loss(double p, double y) {
            return LossFunctions.logLoss(p, y);
        }

        /**
         * Derivative with respect to {@code p}: {@code -y / (1 + exp(y * p))}.
         * When the margin {@code z = y * p} lies outside [-18, 18], the asymptotic forms
         * {@code -y * exp(-z)} (very positive z) and {@code -y} (very negative z) are used.
         *
         * @throws IllegalArgumentException if {@code y} is not {@code +1} or {@code -1}
         */
        @Override
        public float dloss(float p, float y) {
            LogLoss.checkTarget(y);
            float z = y * p;
            if (z > 18.0f) {
                return (float)Math.exp(-z) * -y;
            }
            if (z < -18.0f) {
                return -y;
            }
            return -y / ((float)Math.exp(z) + 1.0f);
        }
    }

    /**
     * Squared-error regression loss {@code (p - y)^2 / 2}.
     */
    public static final class SquaredLoss
    extends RegressionLoss {
        /**
         * Delegates to {@link LossFunctions#squaredLoss(float, float)}. The body was a verbatim
         * copy of that helper, and sharing it keeps the two in sync.
         */
        @Override
        public float loss(float p, float y) {
            return LossFunctions.squaredLoss(p, y);
        }

        /** Delegates to {@link LossFunctions#squaredLoss(double, double)}. */
        @Override
        public double loss(double p, double y) {
            return LossFunctions.squaredLoss(p, y);
        }

        /** Derivative with respect to {@code p}: {@code p - y}. */
        @Override
        public float dloss(float p, float y) {
            return p - y;
        }
    }

    /**
     * Base class for regression losses. Reports {@code forRegression() == true} and
     * {@code forBinaryClassification() == false}.
     */
    public abstract static class RegressionLoss
    implements LossFunction {
        @Override
        public boolean forRegression() {
            return true;
        }

        @Override
        public boolean forBinaryClassification() {
            return false;
        }
    }

    /**
     * Base class for binary-classification losses, whose targets must be exactly +1 or -1.
     * Reports {@code forBinaryClassification() == true} and {@code forRegression() == false}.
     */
    public abstract static class BinaryLoss
    implements LossFunction {
        /**
         * Validates a binary target.
         *
         * @param y the target to check
         * @throws IllegalArgumentException unless {@code y} is exactly +1 or -1 (NaN is rejected)
         */
        protected static void checkTarget(float y) {
            final boolean isBinary = y == 1.0f || y == -1.0f;
            if (!isBinary) {
                throw new IllegalArgumentException("target must be [+1,-1]: " + y);
            }
        }

        /**
         * Double-precision counterpart of {@link #checkTarget(float)}.
         *
         * @param y the target to check
         * @throws IllegalArgumentException unless {@code y} is exactly +1 or -1 (NaN is rejected)
         */
        protected static void checkTarget(double y) {
            final boolean isBinary = y == 1.0 || y == -1.0;
            if (!isBinary) {
                throw new IllegalArgumentException("target must be [+1,-1]: " + y);
            }
        }

        @Override
        public boolean forBinaryClassification() {
            return true;
        }

        @Override
        public boolean forRegression() {
            return false;
        }
    }

    /**
     * A loss function of a prediction {@code p} against a target {@code y}.
     */
    public interface LossFunction {
        /** Returns the loss of prediction {@code p} against target {@code y}. */
        float loss(float p, float y);

        /** Double-precision counterpart of {@link #loss(float, float)}. */
        double loss(double p, double y);

        /** Returns the derivative of the loss with respect to the prediction {@code p}. */
        float dloss(float p, float y);

        /** Whether this loss is intended for binary classification. */
        boolean forBinaryClassification();

        /** Whether this loss is intended for regression. */
        boolean forRegression();
    }

    /**
     * Identifiers of the loss functions that {@code getLossFunction} can construct.
     * Constant names double as the case-insensitive names accepted by the String overload.
     */
    public enum LossType {
        SquaredLoss,
        LogLoss,
        HingeLoss,
        SquaredHingeLoss,
        QuantileLoss,
        EpsilonInsensitiveLoss
    }
}

