package ca.training.bigdata.spark.ml

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.RegexTokenizer
import org.apache.spark.ml.feature.StopWordsRemover
import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer}
import org.apache.spark.ml.feature._
import org.apache.spark.ml.clustering.KMeans

/**
  * Created by BigDataTraining on 2018-04-01.
  *
  * Clusters the 20 Newsgroups corpus with KMeans on normalized TF-IDF features.
  *
  * Steps performed by `main`:
  *  1. every newsgroup directory is read and appended to a parquet dataset;
  *  2. the text is tokenized and stop words are removed;
  *  3. tokens are hashed into term-frequency vectors, re-weighted with IDF and normalized;
  *  4. a KMeans model (k = 20) is fitted, its cost and centers are printed, and a few
  *     cluster/label distributions are displayed.
  */
object NLPKMeans {

  /** Directory that contains one sub-directory of raw text files per newsgroup. */
  private val RawCorpusRoot = "/tmp/20_newsgroups/20_newsgroups"

  /** Parquet location the labelled corpus is appended to and later read back from. */
  private val CorpusParquetPath = "/root/20_newsgroups"

  /** Newsgroup names; each one is both a sub-directory name and the document label. */
  private val Topics: List[String] = List(
    "alt.atheism", "comp.graphics", "comp.os.ms-windows.misc", "comp.sys.ibm.pc.hardware",
    "comp.sys.mac.hardware", "comp.windows.x", "misc.forsale", "rec.autos", "rec.motorcycles",
    "rec.sport.baseball", "rec.sport.hockey", "sci.crypt", "sci.electronics", "sci.med",
    "sci.space", "soc.religion.christian", "talk.politics.guns",
    "talk.politics.mideast", "talk.politics.misc", "talk.religion.misc")

  /**
    * Reads every file of one newsgroup and appends it to the parquet corpus.
    *
    * Each file becomes one row with columns `id` (file path), `text` (file content)
    * and `topic` (the given newsgroup name).
    *
    * NOTE: the write uses append mode, so calling this again for the same topic adds
    * the same documents a second time.
    *
    * @param spark active Spark session
    * @param topic newsgroup name, used as sub-directory name and as the `topic` value
    */
  def process(spark: SparkSession, topic: String): Unit = {
    import spark.implicits._ // needed for RDD.toDF
    spark.sparkContext
      .wholeTextFiles(s"$RawCorpusRoot/$topic")
      .toDF("id", "text")
      .withColumn("topic", lit(topic))
      .write
      .mode("append")
      .format("parquet")
      .save(CorpusParquetPath)
  }

  /**
    * Entry point: ingests the corpus, builds the feature vectors, fits KMeans and
    * prints the resulting cost, cluster centers and cluster/label distributions.
    *
    * @param args command-line arguments (unused)
    */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("NLP using KMeans").getOrCreate()

    Topics.foreach(topic => process(spark, topic))

    // The newsgroup name becomes the label; the file path is not needed any more.
    val corpusDF = spark.read
      .parquet(CorpusParquetPath)
      .drop("label", "id")
      .withColumnRenamed("topic", "label")

    // Split on runs of non-word characters and underscores. The pattern must reach the
    // regex engine as [\W_]+ ; with a doubled backslash the class would only contain a
    // literal backslash, 'W' and '_' and the text would hardly be split at all.
    val tokenizer = new RegexTokenizer()
      .setPattern("[\\W_]+")
      .setMinTokenLength(4)
      .setInputCol("text")
      .setOutputCol("tokens")

    val remover = new StopWordsRemover()
      .setInputCol("tokens")
      .setOutputCol("filtered")

    val filteredDF = remover.transform(tokenizer.transform(corpusDF))

    val hashingTF = new HashingTF()
      .setInputCol("filtered")
      .setOutputCol("hashingTF")
      .setNumFeatures(20000)

    val featurizedDataDF = hashingTF.transform(filteredDF)

    val idfModel = new IDF()
      .setInputCol("hashingTF")
      .setOutputCol("idfOutput")
      .fit(featurizedDataDF)

    // Keep only the label and the TF-IDF vector; intermediate columns are not used below.
    val trimmedTFIDFOutput = idfModel
      .transform(featurizedDataDF)
      .drop("text", "tokens", "filtered", "hashingTF")

    val finalDF = new Normalizer()
      .setInputCol("idfOutput")
      .setOutputCol("features")
      .transform(trimmedTFIDFOutput)

    // One cluster per newsgroup; the fixed seed keeps the initialisation reproducible.
    val kmeans = new KMeans()
      .setK(20)
      .setSeed(1L)
      .setFeaturesCol("features")

    val model = kmeans.fit(finalDF)

    // NOTE(review): computeCost is deprecated in newer Spark releases; confirm the Spark
    // version in use before upgrading.
    val WSSSE = model.computeCost(finalDF)

    println(s"Within Set Sum of Squared Errors = $WSSSE")

    println("Cluster Centers: ")

    model.clusterCenters.foreach(println)

    // Cached because it is scanned once per report below.
    val finalWithClusters = model.transform(finalDF).cache()

    // DataFrames are lazy: without an action such as show() these reports would never
    // run. 20 rows cover every cluster / newsgroup; truncate = false keeps long
    // newsgroup names readable.
    finalWithClusters.groupBy(col("prediction")).count().show(20, truncate = false)

    // Cluster ids are arbitrary; 1 and 19 are simply the clusters inspected here.
    finalWithClusters.filter(col("prediction") === 1)
      .groupBy(col("label")).count().show(20, truncate = false)
    finalWithClusters.filter(col("label") === "rec.sport.hockey")
      .groupBy(col("prediction")).count().show(20, truncate = false)
    finalWithClusters.filter(col("prediction") === 19)
      .groupBy(col("label")).count().show(20, truncate = false)
    finalWithClusters.filter(col("label") === "comp.sys.ibm.pc.hardware")
      .groupBy(col("prediction")).count().show(20, truncate = false)
  }

}
