package ca.training.bigdata.spark.ml

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.linalg.Vectors

/**
  * Trains a random-forest classifier that predicts whether a day's weather `Events`
  * mention rain, using daily aggregates of household electric power consumption
  * joined with daily weather observations.
  *
  * Created by BigDataTraining on 2018-03-10.
  */
object WeatherPrediction {

  /**
    * One record of the tab-separated daily weather file, in column order.
    * Blank or "-" numeric fields are read as 0.0 by [[processRdd]]. Field names are kept
    * exactly as declared (including `Mean_Sea_Leve_PressurehPa`) because they define the
    * Dataset column names used downstream.
    */
  case class DayWeather(CET: String,
                        Max_TemperatureC: Double,
                        Mean_TemperatureC: Double,
                        Min_TemperatureC: Double,
                        Dew_PointC: Double,
                        MeanDew_PointC: Double,
                        Min_DewpointC: Double,
                        Max_Humidity: Double,
                        Mean_Humidity: Double,
                        Min_Humidity: Double,
                        Max_Sea_Level_PressurehPa: Double,
                        Mean_Sea_Leve_PressurehPa: Double,
                        Min_Sea_Level_PressurehPa: Double,
                        Max_VisibilityKm: Double,
                        Mean_VisibilityKm: Double,
                        Min_VisibilitykM: Double,
                        Max_Wind_SpeedKmph: Double,
                        Mean_Wind_SpeedKmph: Double,
                        Max_Gust_SpeedKmph: Double,
                        Precipitationmm: Double,
                        Events: String)

  /**
    * One `;`-separated household power-consumption reading, fields in file order.
    * NOTE(review): the abbreviations presumably mean gap/grp = global active/reactive power,
    * gi = global intensity, sm_1..sm_3 = sub-metering channels — confirm against the dataset.
    */
  case class HouseholdEPC(date: String,
                          time: String,
                          gap: Double,
                          grp: Double,
                          voltage: Double,
                          gi: Double,
                          sm_1: Double,
                          sm_2: Double,
                          sm_3: Double)

  /**
    * Per-date aggregate of [[HouseholdEPC]] readings: `dvoltage` is the daily mean voltage and
    * the other `d*` measures are daily sums, all rounded to 2 decimals; `day`, `month` and
    * `year` are derived from `date`.
    */
  case class HouseholdEPCDTmDay(date: String,
                                day: String,
                                month: String,
                                year: String,
                                dgap: Double,
                                dgrp: Double,
                                dvoltage: Double,
                                dgi: Double,
                                dsm_1: Double,
                                dsm_2: Double,
                                dsm_3: Double)

  /** Number of tab-separated columns in a weather record (one per [[DayWeather]] field). */
  private val WeatherFieldCount = 21

  /** Full day-of-week name (as produced by the "EEEE" date pattern) to 1-based index, Monday = 1. */
  private val DayOfWeekIndex: Map[String, Int] = Map(
    "Monday" -> 1, "Tuesday" -> 2, "Wednesday" -> 3, "Thursday" -> 4,
    "Friday" -> 5, "Saturday" -> 6, "Sunday" -> 7)

  /** Trims a raw weather field and replaces blank or "-" placeholders with "0". */
  private def normalizeWeatherField(field: String): String = {
    val trimmed = field.trim
    if (trimmed.isEmpty || trimmed == "-") "0" else trimmed
  }

  /**
    * Parses tab-separated weather lines (header already removed) into [[DayWeather]] records.
    *
    * @param data raw lines, each with 21 tab-separated fields in [[DayWeather]] order
    * @return one [[DayWeather]] per line; blank or "-" fields become "0" (0.0 for numeric fields)
    * @throws IllegalArgumentException when the RDD is evaluated, if a line has fewer than 21 fields
    * @throws NumberFormatException when the RDD is evaluated, if a numeric field is not a number
    */
  def processRdd(data: RDD[String]): RDD[DayWeather] =
    data.map { line =>
      // limit = -1 keeps trailing empty fields: a plain split("\t") silently drops them, so a
      // record whose last column(s) are blank (e.g. no Events) would otherwise fail on p(20).
      val p = line.split("\t", -1).map(normalizeWeatherField)
      require(p.length >= WeatherFieldCount,
        s"Expected $WeatherFieldCount tab-separated weather fields but found ${p.length}: $line")
      DayWeather(
        p(0),
        p(1).toDouble,
        p(2).toDouble,
        p(3).toDouble,
        p(4).toDouble,
        p(5).toDouble,
        p(6).toDouble,
        p(7).toDouble,
        p(8).toDouble,
        p(9).toDouble,
        p(10).toDouble,
        p(11).toDouble,
        p(12).toDouble,
        p(13).toDouble,
        p(14).toDouble,
        p(15).toDouble,
        p(16).toDouble,
        p(17).toDouble,
        p(18).toDouble,
        p(19).toDouble,
        p(20))
    }

  /**
    * Returns 1.0 if `str` contains `substr`, else 0.0 (also 0.0 when `str` is null),
    * so the result can be used directly as a numeric label.
    */
  def containsSubstring(str: String, substr: String): Double =
    if (str != null && str.contains(substr)) 1.0 else 0.0

  /** Spark UDF applying [[containsSubstring]] with a fixed `substr` to a string column. */
  def udfContains(substr: String): org.apache.spark.sql.expressions.UserDefinedFunction =
    udf((x: String) => containsSubstring(x, substr))

  /**
    * Spark UDF building the feature vector
    * (day-of-week index 1-7, month, dsm_1, dsm_2, dsm_3).
    *
    * @throws IllegalArgumentException raised inside the Spark task if the day name is null or
    *                                  not a full English day name such as "Monday"
    */
  def udfVec(): org.apache.spark.sql.expressions.UserDefinedFunction =
    udf[org.apache.spark.ml.linalg.Vector, String, Int, Double, Double, Double] {
      (dow, monthOfYear, sm1, sm2, sm3) =>
        val dayIndex = DayOfWeekIndex.getOrElse(dow,
          throw new IllegalArgumentException(s"Unrecognised day-of-week name: $dow"))
        Vectors.dense(dayIndex.toDouble, monthOfYear.toDouble, sm1, sm2, sm3)
    }


  /**
    * Loads both datasets, builds daily features and a rain label, trains a random forest
    * on a 70/30 random split and prints the test error and the learned model.
    *
    * @param args optional overrides: args(0) = household power-consumption file,
    *             args(1) = weather file; each defaults to the original hard-coded path
    */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("Weather Prediction Using Random Forest").getOrCreate()
    import spark.implicits._
    val sc = spark.sparkContext

    val householdPath = args.headOption.getOrElse("file:///root/household_power_consumption.txt")
    val weatherPath = args.drop(1).headOption
      .getOrElse("file:////root/TrainingOnHDP/dataset/spark/weather_201701.txt")

    // Household dates are dd/MM/yyyy strings. These column expressions refer to `date` by name,
    // so they resolve against whichever DataFrame they are applied to.
    val datePattern = "dd/MM/yyyy"
    val parsedDate = to_date(unix_timestamp($"date", datePattern).cast("timestamp"))
    // Exactly four 'E's yields the full day name ("Monday") under both Spark 2.x
    // (SimpleDateFormat) and Spark 3.x; Spark 3.x rejects five or more.
    val dayOfWeek = from_unixtime(unix_timestamp($"date", datePattern), "EEEE")

    val householdLines = sc.textFile(householdPath)
    val header = householdLines.first()
    // Drop the header line and any reading containing "?" (a missing value).
    val data1 = householdLines.filter(row => row != header).filter(row => !row.contains("?"))

    val hhEPCDF = data1
      .map(_.split(";"))
      .map(p => HouseholdEPC(p(0).trim, p(1).trim, p(2).toDouble, p(3).toDouble, p(4).toDouble,
        p(5).toDouble, p(6).toDouble, p(7).toDouble, p(8).toDouble))
      .toDF()
    hhEPCDF.show(5)

    val hhEPCDatesDf = hhEPCDF
      .withColumn("dow", dayOfWeek)
      .withColumn("day", dayofmonth(parsedDate))
      .withColumn("month", month(parsedDate))
      .withColumn("year", year(parsedDate))

    hhEPCDatesDf.show(5)

    // One row per date: daily sums of the power/intensity/sub-metering readings, mean voltage.
    val finalDayDf1 = hhEPCDF
      .drop("time")
      .groupBy($"date")
      .agg(sum($"gap").name("A"), sum($"grp").name("B"), avg($"voltage").name("C"),
        sum($"gi").name("D"), sum($"sm_1").name("E"), sum($"sm_2").name("F"), sum($"sm_3").name("G"))
      .select($"date",
        round($"A", 2).name("dgap"),
        round($"B", 2).name("dgrp"),
        round($"C", 2).name("dvoltage"),
        round($"D", 2).name("dgi"), // daily sum of gi (not the mean voltage in C)
        round($"E", 2).name("dsm_1"),
        round($"F", 2).name("dsm_2"),
        round($"G", 2).name("dsm_3"))
      .withColumn("day", dayofmonth(parsedDate))
      .withColumn("month", month(parsedDate))
      .withColumn("year", year(parsedDate))

    val ds1 = finalDayDf1.as[HouseholdEPCDTmDay]

    val weatherLines = sc.textFile(weatherPath)
    val header2 = weatherLines.first()
    val ds2 = processRdd(weatherLines.filter(row => row != header2))
      .toDF()
      // processRdd turns blank/"-" fields into "0"; mark them "NA" in the string columns instead.
      .na.replace(Seq("CET", "Events"), Map("0" -> "NA"))
      .as[DayWeather]

    // Inner equi-join on the raw date string.
    // NOTE(review): assumes CET uses the same dd/MM/yyyy format as the household file — confirm.
    val joined_ds = ds1.join(ds2, ds1("date") === ds2("CET"))

    val joined_rained_ds = joined_ds
      .withColumn("dow", dayOfWeek)
      .withColumn("label", udfContains("Rain")($"Events"))
      .withColumn("features", udfVec()($"dow", $"month", $"dsm_1", $"dsm_2", $"dsm_3"))
      // Reused by both indexers, the split and training: avoid recomputing the join and UDFs.
      .cache()

    // Indexers are fitted on the full dataset so every label/category value is known to them.
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(joined_rained_ds)

    // Features with at most 7 distinct values (e.g. the day-of-week index) are treated as categorical.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(7)
      .fit(joined_rained_ds)

    val Array(trainingData, testData) = joined_rained_ds.randomSplit(Array(0.7, 0.3))

    //Code for Creating and running a machine learning pipeline section
    val rf = new RandomForestClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
      .setNumTrees(10)

    // Convert indexed labels back to original labels.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)

    // Chain indexers and forest in a Pipeline.
    val pipeline = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

    // Train model. This also runs the indexers.
    val model = pipeline.fit(trainingData)

    // Make predictions.
    val predictions = model.transform(testData)

    predictions.show(5)

    // Select example rows to display.
    predictions.select("predictedLabel", "label", "features").show(5)

    // Select (prediction, true label) and compute test error.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")

    val accuracy = evaluator.evaluate(predictions)

    println("Test Error = " + (1.0 - accuracy))

    // Locate the trained forest by type rather than by a hard-coded stage index.
    model.stages
      .collectFirst { case rfModel: RandomForestClassificationModel => rfModel }
      .foreach(rfModel => println("Learned classification forest model:\n" + rfModel.toDebugString))

    spark.stop()
  }

}
