package ca.training.bigdata.spark.ml

import org.apache.spark.sql.SparkSession
import breeze.linalg._
import org.apache.spark.storage.StorageLevel

/**
  * Created by BigDataTraining on 2018-03-10.
  */
object SentenceSimilarity {

  /**
    * Builds the cross product of the content words of two sentences.
    *
    * Each sentence is lower-cased, stripped of ASCII punctuation (`\p{Punct}`),
    * split on runs of whitespace and filtered against `stopwords`. Empty tokens
    * (produced by leading whitespace, empty sentences, or tokens that consisted
    * only of punctuation) are discarded so they never show up as "words".
    *
    * @param id        identifier of the sentence pair; copied into every result entry
    * @param s1        first sentence
    * @param s2        second sentence
    * @param stopwords words to ignore; compared against the lower-cased,
    *                  punctuation-free tokens
    * @return one `(id, (w1, w2))` entry for every combination of a remaining word
    *         of `s1` with a remaining word of `s2`, ordered by `s1` word first;
    *         empty if either sentence has no remaining words
    */
  def getWordPairs(id: Long, s1: String, s2: String, stopwords: Set[String]):
  List[(Long, (String, String))] = {
    // Shared normalisation for both sentences.
    def tokenize(sentence: String): List[String] =
      sentence.toLowerCase
        .replaceAll("\\p{Punct}", "")
        .split("\\s+")
        .iterator
        // split can still emit "" (e.g. for leading whitespace or an empty sentence)
        .filter(w => w.nonEmpty && !stopwords.contains(w))
        .toList

    val rightWords = tokenize(s2)
    for (w1 <- tokenize(s1); w2 <- rightWords) yield (id, (w1, w2))
  }

  /**
    * One scored sentence pair; `main` converts an RDD of these into the result
    * DataFrame.
    *
    * @param s1  first sentence of the pair
    * @param s2  second sentence of the pair
    * @param wmd distance score of the pair; in `main` this is the sum, over the
    *            words of `s1`, of the distance to the closest word of `s2`
    */
  case class SentencePair(
    s1: String,
    s2: String,
    wmd: Double
  )

  /**
    * Euclidean distance between two vectors that are passed as text.
    *
    * @param lvec comma-separated numeric components of the first vector
    * @param rvec comma-separated numeric components of the second vector
    * @return the square root of the summed squared component differences
    */
  def dist(lvec: String, rvec: String): Double = {
    // Turns "1.0,2.0,..." into a Breeze vector; a non-numeric component makes
    // `toDouble` throw.
    def parse(csv: String) = DenseVector(csv.split(',').map(_.toDouble))

    val delta = parse(lvec) - parse(rvec)
    // NOTE(review): `:*` (element-wise product) is reported as deprecated in newer
    // Breeze releases in favour of `*:*` - confirm the Breeze version before switching.
    math.sqrt(sum(delta :* delta))
  }

  /**
    * Spark driver: scores every sentence pair of the input file and prints the
    * result ordered by first sentence and ascending score.
    *
    * Pipeline:
    *  1. read tab-separated sentence pairs and give each pair an index,
    *  2. expand every pair into its (left word, right word) combinations,
    *  3. look up the vector of both words (pairs with a word that has no vector
    *     are dropped by the inner joins),
    *  4. per (pair, left word) keep the distance to the closest right word,
    *  5. sum those minima per pair and join them back to the sentences.
    *
    * Sentence pairs for which no word combination survives step 3 do not appear
    * in the output.
    *
    * @param args optional positional overrides, each falling back to the previously
    *             hard-coded location:
    *             `args(0)` sentence-pair file, `args(1)` stopword file,
    *             `args(2)` word-vector file
    */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("Fuzzy Matching in Spark with Sentence Similarity Using Word2Vec").getOrCreate()
    import spark.implicits._
    val sc = spark.sparkContext

    val sentencePairsPath = args.lift(0).getOrElse("/tmp/sentence_pairs.txt")
    val stopwordsPath = args.lift(1).getOrElse("/tmp/stopwords.txt")
    val vectorsPath = args.lift(2).getOrElse("file:///root/GoogleNews-vectors-negative300.tsv")

    try {
      // Lines that do not consist of exactly two tab-separated fields (blank lines,
      // truncated records) are skipped instead of failing the job with a MatchError.
      val sentencePairs = sc.textFile(sentencePairsPath)
        .map(_.split('\t'))
        .collect { case Array(s1, s2) => (s1, s2) }
        .zipWithIndex
        .persist(StorageLevel.MEMORY_AND_DISK)

      // Materialise the cache up front; the RDD is read twice below.
      sentencePairs.count()

      // Tokens are lower-cased before the stopword check, so the stopwords are
      // normalised the same way; surrounding whitespace and blank lines are ignored.
      val stopwords = sc.textFile(stopwordsPath)
        .map(_.trim.toLowerCase)
        .filter(_.nonEmpty)
        .collect()
        .toSet
      val bStopwords = sc.broadcast(stopwords)

      val wordPairs = sentencePairs.flatMap { case ((s1, s2), id) =>
        getWordPairs(id, s1, s2, bStopwords.value)
      }

      // The first line of the vector file is treated as a header and removed.
      val vectorLines = sc.textFile(vectorsPath)
      val header = vectorLines.first()
      val w2vs = vectorLines
        .filter(_ != header)
        .map { line =>
          val values = line.split(' ')
          //val Array(word, vector) = line.split('\t')
          (values.head, values.tail.mkString(","))
        }

      // ((pair index, left word), distance to one right word)
      val pairDistances = wordPairs
        .map { case (idx, (lword, rword)) => (rword, (idx, lword)) }
        .join(w2vs)    // (rword, ((idx, lword), rvec))
        .map { case (_, ((idx, lword), rvec)) => (lword, (idx, rvec)) }
        .join(w2vs)    // (lword, ((idx, rvec), lvec))
        .map { case (lword, ((idx, rvec), lvec)) => ((idx, lword), dist(lvec, rvec)) }

      val bestWMDs = pairDistances
        .reduceByKey((a, b) => math.min(a, b))        // dist to closest word
        .map { case ((idx, _), wmd) => (idx, wmd) }
        .reduceByKey(_ + _)                           // sum all wmds for sent

      val results = sentencePairs.map(_.swap)
        .join(bestWMDs)
        .map { case (_, ((s1, s2), wmd)) => SentencePair(s1, s2, wmd) }
        .toDF()
        .orderBy($"s1".asc, $"wmd".asc)

      results.show()
    } finally {
      // Release the session (and with it the cached RDD) even when the job fails.
      spark.stop()
    }
  }


}
