package ca.training.bigdata.spark.ml

import org.apache.spark.ml.feature.{RegexTokenizer, StopWordsRemover}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{Pipeline, UnaryTransformer}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DataTypes, StringType}

import scala.util.matching.Regex

/**
  * Unary transformer over a string column that blanks out HTML tables and any
  * remaining markup tags.
  *
  * @param uid unique identifier of this transformer instance
  */
class TablesNHTMLElemCleaner(override val uid: String) extends UnaryTransformer[String, String, TablesNHTMLElemCleaner] {
  def this() = this(Identifiable.randomUID("cleaner"))

  /** Row-level function applied to every value of the input column. */
  override protected def createTransformFunc: String => String =
    text => deleteTablesNHTMLElem(text)

  /**
    * Replaces every `<Table ...</Table>` span and then every remaining `<...>` tag
    * with a single space.
    *
    * @param instr raw text
    * @return the text with tables and tags replaced by spaces
    */
  def deleteTablesNHTMLElem(instr: String): String = {
    // (?s) lets '.' cross line breaks, (?i) ignores case, '.*?' stops at the first closing tag.
    val tableSpan = "(?s)(?i)<Table.*?</Table>".r
    // Whatever markup is left once whole tables are gone.
    val anyTag = "(?s)<[^>]*>".r
    anyTag.replaceAllIn(tableSpan.replaceAllIn(instr, " "), " ")
  }

  /** Only string input columns are accepted. */
  override protected def validateInputType(inputType: DataType): Unit =
    require(inputType == StringType)

  /** The output column is a string column. */
  override protected def outputDataType: DataType = DataTypes.StringType
}

/**
  * Unary transformer over a string column that removes URLs and file names and then
  * replaces every character that is neither an ASCII letter nor a period with a space.
  *
  * @param uid unique identifier of this transformer instance
  */
class AllURLsFileNamesDigitsPunctuationExceptPeriodCleaner(override val uid: String) extends UnaryTransformer[String, String, AllURLsFileNamesDigitsPunctuationExceptPeriodCleaner] {
  def this() = this(Identifiable.randomUID("cleaner"))

  /** Row-level function applied to every value of the input column. */
  override protected def createTransformFunc: String => String = {
    deleteAllURLsFileNamesDigitsPunctuationExceptPeriod _
  }

  /**
    * Cleans one string in three passes: URLs, file names, then all remaining
    * non-letter characters except the period.
    *
    * @param instr raw text
    * @return text containing only ASCII letters, periods and spaces
    */
  def deleteAllURLsFileNamesDigitsPunctuationExceptPeriod(instr: String): String = {
    // http(s)/ftp/file URLs are deleted outright.
    val pattern1 = new Regex("\\b(https?|ftp|file)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]")
    val str1 = pattern1.replaceAllIn(instr, "")
    // File names: the dot before the extension must be a literal dot, otherwise ordinary
    // words such as "nightmare" (ni + g + htm) are eaten. `html?` matches the full "html"
    // extension, and the trailing \b keeps the match from stopping inside a longer word.
    val pattern2 = new Regex("[_a-zA-Z0-9\\-\\.]+\\.(html?|sgml|txt|xml|xsd)\\b")
    val str2 = pattern2.replaceAllIn(str1, " ")
    // Inside a character class '|' and '^' are literals, so they are not listed here:
    // only letters and the period survive.
    val pattern3 = new Regex("[^a-zA-Z.]")
    val str3 = pattern3.replaceAllIn(str2, " ")
    str3
  }

  /** Only string input columns are accepted. */
  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType, s"Input type must be StringType but got $inputType.")
  }

  /** The output column is a string column. */
  override protected def outputDataType: DataType = DataTypes.StringType
}

/**
  * Unary transformer over a string column that keeps only ASCII letters, separated by
  * single spaces.
  *
  * @param uid unique identifier of this transformer instance
  */
class OnlyAlphasCleaner(override val uid: String) extends UnaryTransformer[String, String, OnlyAlphasCleaner] {
  def this() = this(Identifiable.randomUID("cleaner"))

  /** Row-level function applied to every value of the input column. */
  override protected def createTransformFunc: String => String = {
    keepOnlyAlphas _
  }

  /**
    * Replaces every non-letter character with a space and collapses runs of
    * whitespace into one space.
    *
    * @param instr raw text
    * @return text containing only ASCII letters and single spaces
    */
  def keepOnlyAlphas(instr: String): String = {
    // '|' is a literal inside a character class, so listing it would let pipe
    // characters through; the class therefore names letters only.
    val pattern1 = new Regex("[^a-zA-Z]")
    val str1 = pattern1.replaceAllIn(instr, " ")
    val str2 = str1.replaceAll("[\\s]+", " ")
    str2
  }

  /** Only string input columns are accepted. */
  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType, s"Input type must be StringType but got $inputType.")
  }

  /** The output column is a string column. */
  override protected def outputDataType: DataType = DataTypes.StringType
}

/**
  * Unary transformer over a string column that collapses every run of line feeds,
  * carriage returns, tabs and other whitespace into a single space.
  *
  * @param uid unique identifier of this transformer instance
  */
class ExcessLFCRWSCleaner(override val uid: String) extends UnaryTransformer[String, String, ExcessLFCRWSCleaner] {
  def this() = this(Identifiable.randomUID("cleaner"))

  /** Row-level function applied to every value of the input column. */
  override protected def createTransformFunc: String => String = {
    deleteExcessLFCRWS _
  }

  /**
    * Replaces each maximal run of whitespace with one space.
    *
    * @param instr raw text
    * @return text in which every whitespace run is a single space
    */
  def deleteExcessLFCRWS(instr: String): String = {
    // `\s` already covers \n, \r and \t, so this single pass yields exactly what the
    // earlier newline pass + tab pass + whitespace pass produced.
    val whitespaceRun = new Regex("\\s+")
    whitespaceRun.replaceAllIn(instr, " ")
  }

  /** Only string input columns are accepted. */
  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType, s"Input type must be StringType but got $inputType.")
  }

  /** The output column is a string column. */
  override protected def outputDataType: DataType = DataTypes.StringType
}

/**
  * Created by BigDataTraining on 2018-03-30.
  *
  * Spark driver that runs the text-cleaning pipeline over two `.sgm` input files, joins
  * the cleaned words against the dictionary CSV and, for each file, shows the matched
  * words, prints the share of matched words flagged as negative and shows the number of
  * matched words per non-zero `modal` value.
  */
object TextualAnalysisPipeline {

  /**
    * Entry point. All input locations are hard-coded; `args` is not read.
    *
    * @param args command-line arguments (unused)
    */
  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().appName("Machine learning application for textual analysis").getOrCreate()

    try {
      val sc = spark.sparkContext
      import spark.implicits._

      val tablesNHTMLElemCleaner = new TablesNHTMLElemCleaner()
        .setInputCol("value")
        .setOutputCol("tablesNHTMLElemCleaned")

      val allURLsFileNamesDigitsPunctuationExceptPeriodCleaner = new AllURLsFileNamesDigitsPunctuationExceptPeriodCleaner()
        .setInputCol("tablesNHTMLElemCleaned")
        .setOutputCol("allURLsFileNamesDigitsPunctuationExceptPeriodCleaned")

      val onlyAlphasCleaner = new OnlyAlphasCleaner()
        .setInputCol("allURLsFileNamesDigitsPunctuationExceptPeriodCleaned")
        .setOutputCol("text")

      val excessLFCRWSCleaner = new ExcessLFCRWSCleaner()
        .setInputCol("text")
        .setOutputCol("cleaned")

      val tokenizer = new RegexTokenizer()
        .setInputCol("cleaned")
        .setOutputCol("words")
        .setPattern("\\W")

      val stopwords: Array[String] = sc.textFile("file:////root/TrainingOnHDP/dataset/spark/StopWords_GenericLong.txt").flatMap(_.stripMargin.split("\\s+")).collect

      val remover = new StopWordsRemover()
        .setStopWords(stopwords)
        .setCaseSensitive(false)
        .setInputCol("words")
        .setOutputCol("filtered")

      val pipeline = new Pipeline()
        .setStages(Array(tablesNHTMLElemCleaner,
          allURLsFileNamesDigitsPunctuationExceptPeriodCleaner,
          onlyAlphasCleaner,
          excessLFCRWSCleaner,
          tokenizer,
          remover))

      // Each line of the text file becomes one row in a single "value" column.
      def readLines(path: String) = sc.textFile(path).toDF()

      val model = pipeline.fit(readLines("file:///root/reut2-020.sgm"))

      val dictDF = spark.read.format("csv").option("header", "true")
        .load("file:///root/LoughranMcDonald_MasterDictionary_2014.csv")

      // Runs the fitted pipeline over one input file and reports its dictionary statistics.
      def report(path: String): Unit = {
        // Only "cleaned" is used downstream, so select it directly instead of dropping
        // the other columns one by one (two of the former drops named columns that do
        // not exist and were silent no-ops).
        // NOTE(review): the tokenizer / stop-word output ("words", "filtered") is not used
        // by the analysis below - confirm whether "filtered" was meant to be exploded.
        val wordsInStoryDF = model
          .transform(readLines(path))
          .select("cleaned")
          .filter(($"cleaned" =!= "") && ($"cleaned" =!= " "))
          .select(explode(split($"cleaned", "[\\s]")).as("wordsInStory"))

        // Four actions run against the join result, so keep it cached while reporting.
        val joinWordsDict = wordsInStoryDF
          .join(dictDF, lower(wordsInStoryDF("wordsInStory")) === lower(dictDF("Word")))
          .cache()

        val numWords = joinWordsDict.count().toDouble
        joinWordsDict.select("wordsInStory").show()

        val negWordCount = joinWordsDict.select("wordsInStory", "negative").where(joinWordsDict("negative") > 0).count()
        // Evaluates to NaN when no word of the file matched the dictionary.
        val sentiment = negWordCount / numWords
        println(s"Share of negative words in $path: $sentiment")

        // The filter column comes from this file's own join result.
        val modalWordCount = joinWordsDict.select("wordsInStory", "modal")
          .where(joinWordsDict("modal") > 0).groupBy("modal").count()
        modalWordCount.show()

        joinWordsDict.unpersist()
      }

      report("file:///root/reut2-020.sgm")
      report("file:///root/reut2-008.sgm")
    } finally {
      // Release the session even when one of the jobs fails.
      spark.stop()
    }
  }
}
