package ca.training.bigdata.spark.ml

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

/**
  * Spark ML job that trains a multinomial Naive Bayes text classifier on Amazon
  * reviews and prints its error on a held-out test split.
  *
  * The 1-5 `overall` score of each review is collapsed into three classes
  * (see [[AmazonReviewPrediction.udfReviewBins]]). Per class, at most 120000
  * randomly ordered reviews are sampled; ranks 1..100000 are used for training
  * and the remaining ranks for testing. Features are hashed term frequencies
  * of unigrams, bigrams and trigrams, and the hash sizes are tuned with 5-fold
  * cross-validation.
  *
  * Usage: the first program argument, when present, is the path of the reviews
  * JSON file; otherwise the built-in default path is read.
  *
  * Created by BigDataTraining on 2018-03-31.
  */
object AmazonReviewPrediction {

  /** Input used when no path is supplied on the command line. */
  private val DefaultReviewsPath = "file:///root/reviews_Electronics_5.json"

  /** Highest per-label row number that belongs to the training split; larger ranks are test rows. */
  private val TrainRowsPerLabel = 100000

  /**
    * Builds the UDF that maps a raw review score onto a class label:
    * 1.0 and 2.0 become 1.0, 3.0 becomes 2.0, 4.0 and 5.0 become 3.0.
    *
    * @return a UDF from the raw score to its class label
    * @throws IllegalArgumentException (at execution time, inside the UDF) when
    *                                  the score is not one of 1.0 .. 5.0
    */
  def udfReviewBins(): org.apache.spark.sql.expressions.UserDefinedFunction =
    udf[Double, Double] { score =>
      score match {
        case 1.0 | 2.0 => 1.0
        case 3.0       => 2.0
        case 4.0 | 5.0 => 3.0
        // Fail with the offending value instead of an anonymous MatchError.
        case other     => throw new IllegalArgumentException(s"Unexpected review rating: $other")
      }
    }

  /**
    * Entry point: creates the session, runs the experiment and always stops the
    * session afterwards, even when the run fails.
    *
    * @param args optional; `args(0)` overrides the default reviews JSON path
    */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("Amazon Review Prediction Using Naive Bayes").getOrCreate()
    try run(spark, args.headOption.getOrElse(DefaultReviewsPath))
    finally spark.stop()
  }

  /**
    * Loads the reviews, builds the feature/classifier pipeline, tunes it with
    * cross-validation on the training split and prints the test error.
    *
    * @param spark       active session used for reading data and running SQL
    * @param reviewsPath location of the reviews JSON file
    */
  private def run(spark: SparkSession, reviewsPath: String): Unit = {
    import spark.implicits._

    val inDF = spark.read.json(reviewsPath)
    val modifiedInDF = inDF.withColumn("rating", udfReviewBins()($"overall")).drop("overall")
    modifiedInDF.show()
    modifiedInDF.groupBy("rating").count().orderBy("rating").show()
    modifiedInDF.createOrReplaceTempView("modReviewsTable")

    // Rank the reviews of every label in random order and keep the first 120000.
    // Cached because this frame feeds several separate jobs below (label counts,
    // indexer fit, every cross-validation fold and the final test scoring).
    val reviewsDF = spark.sql(
      """
      SELECT text, label, rowNumber FROM (
        SELECT  rating AS label, reviewText AS text, row_number() OVER (PARTITION BY rating ORDER BY rand()) AS rowNumber FROM modReviewsTable
        ) modReviewsTable
      WHERE rowNumber <= 120000
      """).cache()
    reviewsDF.groupBy("label").count().orderBy("label").show()

    // Disjoint split on the per-label rank: one shared boundary guarantees that
    // no row can be part of both the training and the test data.
    val trainingData = reviewsDF
      .filter(reviewsDF("rowNumber") <= TrainRowsPerLabel)
      .select("text", "label")
    val testData = reviewsDF
      .filter(reviewsDF("rowNumber") > TrainRowsPerLabel)
      .select("text", "label")

    // Tokens are the regex matches themselves (gaps = false), not the text between them.
    val regexTokenizer = new RegexTokenizer()
      .setPattern("[a-zA-Z']+")
      .setGaps(false)
      .setInputCol("text")

    val remover = new StopWordsRemover()
      .setInputCol(regexTokenizer.getOutputCol)

    val bigrams = new NGram().setN(2).setInputCol(remover.getOutputCol)
    val trigrams = new NGram().setN(3).setInputCol(remover.getOutputCol)

    val removerHashingTF = new HashingTF().setInputCol(remover.getOutputCol)
    val ngram2HashingTF = new HashingTF().setInputCol(bigrams.getOutputCol)
    val ngram3HashingTF = new HashingTF().setInputCol(trigrams.getOutputCol)

    // Concatenate the three term-frequency vectors into a single feature vector.
    val assembler = new VectorAssembler()
      .setInputCols(Array(removerHashingTF.getOutputCol, ngram2HashingTF.getOutputCol, ngram3HashingTF.getOutputCol))

    // Fitted up front because its label list is needed to configure IndexToString below.
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(reviewsDF)

    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)

    val nb = new NaiveBayes()
      .setLabelCol(labelIndexer.getOutputCol)
      .setFeaturesCol(assembler.getOutputCol)
      .setPredictionCol("prediction")
      .setModelType("multinomial")

    val pipeline = new Pipeline()
      .setStages(Array(regexTokenizer, remover, bigrams, trigrams, removerHashingTF, ngram2HashingTF, ngram3HashingTF, assembler, labelIndexer, nb, labelConverter))

    // 2 x 2 x 2 = 8 candidate settings for the three hash sizes.
    val paramGrid = new ParamGridBuilder()
      .addGrid(removerHashingTF.numFeatures, Array(1000, 10000))
      .addGrid(ngram2HashingTF.numFeatures, Array(1000, 10000))
      .addGrid(ngram3HashingTF.numFeatures, Array(1000, 10000))
      .build()

    // One evaluator serves both model selection and the final test measurement.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")

    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(5)

    val cvModel = cv.fit(trainingData)

    val predictions = cvModel.transform(testData)
    val accuracy = evaluator.evaluate(predictions)

    println(s"Test Error = ${1.0 - accuracy}")
  }

}
