package ca.training.bigdata.spark.ml

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.classification.{DecisionTreeClassifier, DecisionTreeClassificationModel}
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.tuning._

/**
  * Created by BigDataTraining on 2018-04-01.
  *
  * Trains decision-tree classifiers on the MNIST digit files (libsvm format) and prints
  * test-set accuracy for three experiments:
  *
  *  1. a sweep over `maxDepth` (0 to 7),
  *  2. a 3-fold cross-validated grid search over `maxDepth` (4 to 7),
  *  3. a sweep over `maxBins` (2, 4, 8, 16, 32) at a fixed `maxDepth` of 6.
  */
object HandwrittenDigitRecognition {

  private val TrainingPath = "file:///root/TrainingOnHDP/dataset/spark/mnist-digits-train.txt"
  private val TestPath = "file:///root/TrainingOnHDP/dataset/spark/mnist-digits-test.txt"

  private val IndexedLabelCol = "indexedLabel"

  /** Builds a fresh classifier so that no experiment inherits parameters set by another. */
  private def newTree(): DecisionTreeClassifier =
    new DecisionTreeClassifier().setLabelCol(IndexedLabelCol)

  /**
    * Runs the three experiments and prints their results to standard output.
    *
    * @param args command-line arguments (unused)
    */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("Handwritten digit recognition using Decision Tree").getOrCreate()
    import spark.implicits._

    try {
      // Both data sets are read many times below, so keep them cached.
      val training = spark.read.format("libsvm").load(TrainingPath).cache()
      val test = spark.read.format("libsvm").load(TestPath).cache()

      val indexer = new StringIndexer().setInputCol("label").setOutputCol(IndexedLabelCol)

      // The metric must be set explicitly: the evaluator's default is f1, while every
      // result column below is reported as "accuracy".
      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol(IndexedLabelCol)
        .setMetricName("accuracy")

      // Experiment 1: test accuracy as a function of tree depth.
      val accuracies = (0 until 8).map { maxDepth =>
        val pipeline = new Pipeline().setStages(Array(indexer, newTree().setMaxDepth(maxDepth)))
        val model = pipeline.fit(training)
        (maxDepth, evaluator.evaluate(model.transform(test)))
      }.toDF("maxDepth", "accuracy")

      accuracies.show()

      // Experiment 2: choose maxDepth by 3-fold cross-validation on the training set.
      // The grid must reference the param of the very classifier instance in the pipeline.
      val cvTree = newTree()
      val cvPipeline = new Pipeline().setStages(Array(indexer, cvTree))

      val grid = new ParamGridBuilder()
        .addGrid(cvTree.maxDepth, (4 until 8).toArray)
        .build()

      val cv = new CrossValidator()
        .setNumFolds(3)
        .setEstimator(cvPipeline)
        .setEstimatorParamMaps(grid)
        .setEvaluator(evaluator)

      val cvModel = cv.fit(training)

      // avgMetrics is aligned index-by-index with the estimator param maps.
      cvModel.getEstimatorParamMaps.zip(cvModel.avgMetrics).foreach { case (params, avgAccuracy) =>
        println(s"maxDepth=${params(cvTree.maxDepth)} average cross-validation accuracy=$avgAccuracy")
      }
      println(s"Best cross-validated model test accuracy=${evaluator.evaluate(cvModel.transform(test))}")

      // Experiment 3: test accuracy as a function of maxBins at a fixed depth of 6.
      val accuracies1 = Seq(2, 4, 8, 16, 32).map { maxBins =>
        val tree = newTree().setMaxDepth(6).setMaxBins(maxBins)
        val model = new Pipeline().setStages(Array(indexer, tree)).fit(training)
        (maxBins, evaluator.evaluate(model.transform(test)))
      }.toDF("maxBins", "accuracy")

      accuracies1.show()
    } finally {
      // Release the session even if one of the experiments fails.
      spark.stop()
    }
  }

}
