package ca.training.bigdata.spark.sql.catalyst

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, EqualTo}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{Row, SQLContext, SparkSession, Strategy}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.datasources._

/**
  * Demo of a custom Catalyst planner [[Strategy]]. An inner equi-join between two
  * `SimpleRelation`s on their first column is planned as a direct enumeration of the
  * overlap of the two relations' `[start, end]` ranges, instead of scanning both sides
  * and joining them.
  *
  * Created by BigDataTraining on 2018-03-10.
  */
object SmartSimpleJoin {

  /**
    * Physical operator that emits one row `(i, i)` for every `i` in `start to end`
    * (inclusive). This is the join result over the overlap of the two input ranges.
    *
    * @param leftOutput  join key attribute of the left relation (first output column)
    * @param rightOutput join key attribute of the right relation (second output column)
    * @param start       first value of the overlapping range, inclusive
    * @param end         last value of the overlapping range, inclusive
    */
  case class OverlappingRangeJoin(
      leftOutput: Attribute,
      rightOutput: Attribute,
      start: Int,
      end: Int)
    extends SparkPlan {

    def output: Seq[Attribute] = leftOutput :: rightOutput :: Nil

    def doExecute(): org.apache.spark.rdd.RDD[InternalRow] = {
      import org.apache.spark.sql.catalyst.expressions.UnsafeProjection

      // Spark 2.x physical operators are expected to emit UnsafeRows.
      // SparkPlan.executeCollect (used by Dataset.collect) casts every row to UnsafeRow,
      // so emitting the GenericInternalRow built by InternalRow(...) fails with a
      // ClassCastException. Each row is therefore converted with an UnsafeProjection.
      // The projection reuses one output buffer, so it is created per partition.
      // The schema is copied into a local so the closure captures only the StructType,
      // not this plan node.
      val rowSchema = schema
      sqlContext.sparkContext.parallelize(start to end).mapPartitions[InternalRow] { values =>
        val toUnsafe = UnsafeProjection.create(rowSchema)
        values.map(i => toUnsafe(InternalRow(i, i)))
      }
    }

    def children: Seq[SparkPlan] = Nil
  }

  /**
    * Physical operator for a join that is known to produce no rows (the key ranges are disjoint).
    *
    * @param output attributes the (empty) result carries
    */
  case class EmptyJoin(output: Seq[Attribute]) extends SparkPlan {

    def doExecute(): org.apache.spark.rdd.RDD[InternalRow] =
      sqlContext.sparkContext.emptyRDD[InternalRow]

    def children: Seq[SparkPlan] = Nil
  }

  /**
    * Planner strategy that recognises an inner join between two `SimpleRelation`s whose
    * condition equates the first output column of each side, in either order. Such a join is
    * planned as [[OverlappingRangeJoin]], or as [[EmptyJoin]] when the ranges do not overlap.
    * Every other plan yields `Nil`, which defers to Spark's built-in strategies.
    *
    * NOTE(review): the rewrite is only correct if a `SimpleRelation` yields each integer of
    * its `[start, end]` range exactly once in its first column. That class is defined
    * elsewhere, so confirm this assumption there.
    */
  object SmartSimpleJoinStrategy extends Strategy with Serializable {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case Join(l @ LogicalRelation(left: SimpleRelation, _, _),
                r @ LogicalRelation(right: SimpleRelation, _, _),
                Inner, Some(EqualTo(a, b)))
          if (a == l.output.head && b == r.output.head) ||
            (a == r.output.head && b == l.output.head) =>
        val leftKey = l.output.head
        val rightKey = r.output.head
        // Two inclusive ranges intersect iff each one starts no later than the other one ends.
        if (left.start <= right.end && left.end >= right.start) {
          OverlappingRangeJoin(
            leftKey,
            rightKey,
            math.max(left.start, right.start),
            math.min(left.end, right.end)) :: Nil
        } else {
          // Ranges don't overlap, so the join result is empty.
          EmptyJoin(leftKey :: rightKey :: Nil) :: Nil
        }
      case _ => Nil
    }
  }

  /**
    * Registers a tiny and a large `SimpleRelation` as temp views and installs
    * [[SmartSimpleJoinStrategy]]. It then prints the extended plan and the result of joining
    * the two views on column `a`.
    */
  def main(args: Array[String]): Unit = {
    // Only temporary views are used, so no Hive catalog (and no Hive classes) is required.
    val spark = SparkSession
      .builder()
      .appName("Smart Simple Join")
      .getOrCreate()

    try {
      val sqlContext = spark.sqlContext

      // createOrReplaceTempView replaces registerTempTable, which is deprecated since Spark 2.0.
      spark.baseRelationToDataFrame(SimpleRelation(1, 1)(sqlContext))
        .createOrReplaceTempView("smallTable")
      spark.baseRelationToDataFrame(SimpleRelation(1, 10000000)(sqlContext))
        .createOrReplaceTempView("bigTable")

      spark.experimental.extraStrategies = SmartSimpleJoinStrategy :: Nil

      val query = spark.sql("SELECT * FROM smallTable s JOIN bigTable b ON s.a = b.a")

      query.explain(extended = true)

      query.collect().foreach(println)
    } finally {
      spark.stop()
    }
  }

}