package ca.training.bigdata.spark.ml

import org.apache.spark.sql.SparkSession
import org.apache.spark._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.sql.Dataset
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.ml.tuning.CrossValidator
import org.apache.spark.ml.feature.VectorAssembler
/**
  * Created by BigDataTraining on 2018-04-01.
  *
  * Spark ML job that trains a decision-tree classifier to predict customer churn.
  *
  * The job reads a training and a test CSV file, down-samples the "False" churn
  * class of the training set, tunes the tree depth with 3-fold cross-validation,
  * and prints the selected tree together with evaluation metrics computed on the
  * test set.
  *
  * Optional program arguments: `args(0)` overrides the training CSV path and
  * `args(1)` overrides the test CSV path. Without arguments the built-in
  * default paths are used.
  */
object ChurningPrediction {

  /** One row of the churn CSV files; field order matches [[schema]]. */
  case class Account(state: String, len: Integer, acode: String,
                     intlplan: String, vplan: String, numvmail: Double,
                     tdmins: Double, tdcalls: Double, tdcharge: Double,
                     temins: Double, tecalls: Double, techarge: Double,
                     tnmins: Double, tncalls: Double, tncharge: Double,
                     timins: Double, ticalls: Double, ticharge: Double,
                     numcs: Double, churn: String)

  /** Explicit CSV schema (schema inference is disabled when reading). */
  val schema: StructType = StructType(Array(
    StructField("state", StringType, true),
    StructField("len", IntegerType, true),
    StructField("acode", StringType, true),
    StructField("intlplan", StringType, true),
    StructField("vplan", StringType, true),
    StructField("numvmail", DoubleType, true),
    StructField("tdmins", DoubleType, true),
    StructField("tdcalls", DoubleType, true),
    StructField("tdcharge", DoubleType, true),
    StructField("temins", DoubleType, true),
    StructField("tecalls", DoubleType, true),
    StructField("techarge", DoubleType, true),
    StructField("tnmins", DoubleType, true),
    StructField("tncalls", DoubleType, true),
    StructField("tncharge", DoubleType, true),
    StructField("timins", DoubleType, true),
    StructField("ticalls", DoubleType, true),
    StructField("ticharge", DoubleType, true),
    StructField("numcs", DoubleType, true),
    StructField("churn", StringType, true)
  ))

  // Default dataset locations, used when no program arguments are supplied.
  private val DefaultTrainPath = "file:///root/TrainingOnHDP/dataset/spark/churn-bigml-80.csv"
  private val DefaultTestPath = "file:///root/TrainingOnHDP/dataset/spark/churn-bigml-20.csv"

  /**
    * Entry point.
    *
    * @param args optional: `args(0)` = training CSV path, `args(1)` = test CSV path
    */
  def main(args: Array[String]): Unit = {
    val trainPath = args.headOption.getOrElse(DefaultTrainPath)
    val testPath = args.drop(1).headOption.getOrElse(DefaultTestPath)

    val spark = SparkSession.builder().appName("Churning Prediction using Decision Tree").getOrCreate()
    // Always release the session, even when training or evaluation fails.
    try run(spark, trainPath, testPath)
    finally spark.stop()
  }

  /** Trains, tunes and evaluates the churn model, printing the results to stdout. */
  private def run(spark: SparkSession, trainPath: String, testPath: String): Unit = {
    import spark.implicits._
    import org.apache.spark.ml.PipelineModel
    import org.apache.spark.ml.linalg.{Vector => MLVector}

    // NOTE(review): no "header" option is set, so the CSV files are presumably
    // header-less -- confirm against the actual dataset files.
    def readAccounts(path: String): Dataset[Account] =
      spark.read.option("inferSchema", "false").schema(schema).csv(path).as[Account]

    val train = readAccounts(trainPath)
    val test = readAccounts(testPath)

    train.printSchema()
    train.createOrReplaceTempView("account")
    spark.catalog.cacheTable("account")

    // Stratified sampling: keep every "True" row but only ~17% of the "False" rows.
    val fractions = Map("False" -> .17, "True" -> 1.0)
    val strain = train.stat.sampleBy("churn", fractions, 36L)

    // Columns dropped here are not part of featureCols below.
    val ntrain = strain.drop("state", "acode", "vplan", "tdcharge", "techarge")
    ntrain.show()

    val ipindexer = new StringIndexer()
      .setInputCol("intlplan")
      .setOutputCol("iplanIndex")

    // NOTE(review): StringIndexer orders labels by frequency, so which churn value
    // becomes 1.0 depends on the class counts in the sampled training set -- confirm
    // that 1.0 is the class you want to treat as "positive" in the metrics below.
    val labelindexer = new StringIndexer()
      .setInputCol("churn")
      .setOutputCol("label")

    val featureCols = Array("len", "iplanIndex", "numvmail", "tdmins", "tdcalls", "temins",
      "tecalls", "tnmins", "tncalls", "timins", "ticalls", "numcs")

    val assembler = new VectorAssembler()
      .setInputCols(featureCols)
      .setOutputCol("features")

    val dTree = new DecisionTreeClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")

    val pipeline = new Pipeline()
      .setStages(Array(ipindexer, labelindexer, assembler, dTree))

    val paramGrid = new ParamGridBuilder()
      .addGrid(dTree.maxDepth, Array(2, 3, 4, 5, 6, 7))
      .build()

    // Score on the classifier's raw prediction vector rather than on the hard 0/1
    // "prediction" column; hard predictions collapse the ROC curve to one point.
    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("label")
      .setRawPredictionCol("rawPrediction")

    val crossval = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)

    val cvModel = crossval.fit(ntrain)

    // Locate the fitted tree by type instead of by position and unchecked casts.
    val bestPipeline = cvModel.bestModel match {
      case fitted: PipelineModel => fitted
      case other =>
        throw new IllegalStateException(s"Expected a PipelineModel but got ${other.getClass.getName}")
    }
    val treeModel = bestPipeline.stages
      .collectFirst { case tree: DecisionTreeClassificationModel => tree }
      .getOrElse(throw new IllegalStateException("No DecisionTreeClassificationModel stage in the best pipeline"))

    println("The Best Model and Parameters:\n--------------------")
    println(treeModel)
    println(treeModel.extractParamMap())
    println("Learned classification tree model:\n" + treeModel.toDebugString)

    // The predictions feed several actions below; cache them so the pipeline
    // is applied to the test set only once.
    val predictions = cvModel.transform(test).cache()

    val evaluatorMetric = evaluator.evaluate(predictions)
    println(evaluator.explainParams())
    println(evaluator.getMetricName + " on the test set: " + evaluatorMetric)

    // Use the predicted probability of class 1.0 as the score so that the
    // threshold-based metrics are computed over real thresholds.
    val scoreAndLabels = predictions.select("probability", "label").rdd.map { row =>
      val probability = row.getAs[MLVector](0)
      (probability(1), row.getDouble(1))
    }

    val metrics = new BinaryClassificationMetrics(scoreAndLabels)

    println("area under the precision-recall curve: " + metrics.areaUnderPR)
    println("area under the receiver operating characteristic (ROC) curve : " + metrics.areaUnderROC)

    // Print the actual (threshold, F-measure) pairs, not the RDD's toString.
    metrics.fMeasureByThreshold().collect().foreach { case (threshold, fMeasure) =>
      println("threshold: " + threshold + ", F-measure: " + fMeasure)
    }

    val result = predictions.select("label", "prediction", "probability")
    result.show()

    // Confusion matrix in a single pass: (label, prediction) -> row count.
    val confusion: Map[(Double, Double), Long] = predictions
      .groupBy("label", "prediction")
      .count()
      .as[(Double, Double, Long)]
      .collect()
      .map { case (label, prediction, rows) => (label, prediction) -> rows }
      .toMap

    def cell(label: Double, prediction: Double): Long =
      confusion.getOrElse((label, prediction), 0L)

    val counttotal = confusion.values.sum
    val correct = confusion.collect { case ((label, prediction), rows) if label == prediction => rows }.sum
    val wrong = counttotal - correct
    val ratioWrong = wrong.toDouble / counttotal.toDouble
    val ratioCorrect = correct.toDouble / counttotal.toDouble
    // Label/prediction 1.0 is the positive class, matching BinaryClassificationMetrics.
    val truep = cell(1.0, 1.0) / counttotal.toDouble
    val truen = cell(0.0, 0.0) / counttotal.toDouble
    val falsep = cell(0.0, 1.0) / counttotal.toDouble
    val falsen = cell(1.0, 0.0) / counttotal.toDouble

    println("counttotal : " + counttotal)
    println("correct : " + correct)
    println("wrong: " + wrong)
    println("ratio wrong: " + ratioWrong)
    println("ratio correct: " + ratioCorrect)
    println("ratio true positive : " + truep)
    println("ratio false positive : " + falsep)
    println("ratio true negative : " + truen)
    println("ratio false negative : " + falsen)

    val equalp = predictions.selectExpr(
      "double(round(prediction)) as prediction", "label",
      """CASE double(round(prediction)) = label WHEN true then 1 ELSE 0 END as equal"""
    )
    equalp.show()

    predictions.unpersist()
  }
}
