Skip to content

Commit

Permalink
Optimize and fix Api validation script (NVIDIA#466)
Browse files Browse the repository at this point in the history
* Optimize  script and  fix script

Signed-off-by: Niranjan Artal <nartal@nvidia.com>

* remove extra whitespaces
  • Loading branch information
nartal1 authored Aug 3, 2020
1 parent c5e3058 commit 0446381
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 17 deletions.
2 changes: 1 addition & 1 deletion api_validation/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<configuration>
<mainClass>com.nvidia.spark.api.ApiValidation</mainClass>
<mainClass>com.nvidia.spark.rapids.api.ApiValidation</mainClass>
</configuration>
</plugin>
</plugins>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ object ApiValidation extends Logging {
}

// Method to convert string to TypeTag where string is fully qualified class Name
def stringToTypeTag[A](execName: String): TypeTag[A] = {
val execNameObject = Class.forName(execName) // obtain object from execName
val runTimeMirror = runtimeMirror(execNameObject.getClassLoader) // obtain runtime mirror
val classSym = runTimeMirror.staticClass(execName) // obtain class symbol for `execNameObject`
val tpe = classSym.selfType // obtain type object for `execNameObject`
def classToTypeTag[A](execName: Class[_ <: _]): TypeTag[A] = {
val runTimeMirror = runtimeMirror(execName.getClassLoader) // obtain runtime mirror
val classSym = runTimeMirror.staticClass(execName.getName)
// obtain class symbol for `execNameObject`
val tpe = classSym.selfType // obtain type object for `execNameObject`
// create a type tag which contains above type object
TypeTag(runTimeMirror, new api.TypeCreator {
def apply[U <: api.Universe with Singleton](m: api.Mirror[U]): U#Type =
Expand All @@ -48,7 +48,8 @@ object ApiValidation extends Logging {

val enabledExecs = List (
"[org.apache.spark.sql.execution.joins.SortMergeJoinExec]",
"[org.apache.spark.sql.execution.aggregate.HashAggregateExec]"
"[org.apache.spark.sql.execution.aggregate.HashAggregateExec]",
"[org.apache.spark.sql.execution.CollectLimitExec]"
)

def printHeaders(a: String, appender: StringBuilder): Unit = {
Expand All @@ -68,10 +69,9 @@ object ApiValidation extends Logging {
val gpuKeys = gpuExecs.keys
var printNewline = false

gpuKeys.map(x => x.getName).foreach { e =>

// Get SparkExecs argNames and types
val sparkTypes = stringToTypeTag(e)
gpuKeys.foreach { e =>
// Get SparkExecs argNames and types
val sparkTypes = classToTypeTag(e)

// Proceed only if the Exec is not enabled
if (!enabledExecs.contains(sparkTypes.toString().replace("TypeTag", ""))) {
Expand All @@ -81,18 +81,21 @@ object ApiValidation extends Logging {
// Note that for some there is no 1-1 mapping between names
// Some Execs are in different packages.
val execType = sparkTypes.tpe.toString.split('.').last

val gpu = execType match {
case "BroadcastHashJoinExec" | "BroadcastExchangeExec" =>
s"org.apache.spark.sql.execution.Gpu" + execType
case "FileSourceScanExec" => s"org.apache.spark.sql.rapids.Gpu" + execType
case "SortMergeJoinExec" => s"com.nvidia.spark.rapids.GpuShuffledHashJoinExec"
case "BroadcastExchangeExec" => s"org.apache.spark.sql.rapids.execution.Gpu" + execType
case "BroadcastHashJoinExec" => s"com.nvidia.spark.rapids.shims.spark300.Gpu" + execType
case "FileSourceScanExec" => s"org.apache.spark.sql.rapids.shims.spark300.Gpu" + execType
case "CartesianProductExec" => s"org.apache.spark.sql.rapids.Gpu" + execType
case "BroadcastNestedLoopJoinExec" =>
s"com.nvidia.spark.rapids.shims.spark300.Gpu" + execType
case "SortMergeJoinExec" | "ShuffledHashJoinExec" =>
s"com.nvidia.spark.rapids.shims.spark300.GpuShuffledHashJoinExec"
case "SortAggregateExec" => s"com.nvidia.spark.rapids.GpuHashAggregateExec"
case _ => s"com.nvidia.spark.rapids.Gpu" + execType
}

// TODO: Add error handling if Type is not present
val gpuTypes = stringToTypeTag(gpu)
val gpuTypes = classToTypeTag(Class.forName(gpu))

val sparkToGpuExecMap = Map(
"org.apache.spark.sql.catalyst.expressions.Expression" ->
Expand Down

0 comments on commit 0446381

Please sign in to comment.