Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize and fix Api validation script #466

Merged
merged 2 commits into from
Aug 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api_validation/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<configuration>
<mainClass>com.nvidia.spark.api.ApiValidation</mainClass>
<mainClass>com.nvidia.spark.rapids.api.ApiValidation</mainClass>
</configuration>
</plugin>
</plugins>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ object ApiValidation extends Logging {
}

// Method to convert string to TypeTag where string is fully qualified class Name
def stringToTypeTag[A](execName: String): TypeTag[A] = {
val execNameObject = Class.forName(execName) // obtain object from execName
val runTimeMirror = runtimeMirror(execNameObject.getClassLoader) // obtain runtime mirror
val classSym = runTimeMirror.staticClass(execName) // obtain class symbol for `execNameObject`
val tpe = classSym.selfType // obtain type object for `execNameObject`
def classToTypeTag[A](execName: Class[_ <: _]): TypeTag[A] = {
val runTimeMirror = runtimeMirror(execName.getClassLoader) // obtain runtime mirror
val classSym = runTimeMirror.staticClass(execName.getName)
// obtain class symbol for `execNameObject`
val tpe = classSym.selfType // obtain type object for `execNameObject`
// create a type tag which contains above type object
TypeTag(runTimeMirror, new api.TypeCreator {
def apply[U <: api.Universe with Singleton](m: api.Mirror[U]): U#Type =
Expand All @@ -48,7 +48,8 @@ object ApiValidation extends Logging {

val enabledExecs = List (
"[org.apache.spark.sql.execution.joins.SortMergeJoinExec]",
"[org.apache.spark.sql.execution.aggregate.HashAggregateExec]"
"[org.apache.spark.sql.execution.aggregate.HashAggregateExec]",
"[org.apache.spark.sql.execution.CollectLimitExec]"
)

def printHeaders(a: String, appender: StringBuilder): Unit = {
Expand All @@ -68,10 +69,9 @@ object ApiValidation extends Logging {
val gpuKeys = gpuExecs.keys
var printNewline = false

gpuKeys.map(x => x.getName).foreach { e =>

// Get SparkExecs argNames and types
val sparkTypes = stringToTypeTag(e)
gpuKeys.foreach { e =>
// Get SparkExecs argNames and types
val sparkTypes = classToTypeTag(e)

// Proceed only if the Exec is not enabled
if (!enabledExecs.contains(sparkTypes.toString().replace("TypeTag", ""))) {
Expand All @@ -81,18 +81,21 @@ object ApiValidation extends Logging {
// Note that for some there is no 1-1 mapping between names
// Some Execs are in different packages.
val execType = sparkTypes.tpe.toString.split('.').last

val gpu = execType match {
case "BroadcastHashJoinExec" | "BroadcastExchangeExec" =>
s"org.apache.spark.sql.execution.Gpu" + execType
case "FileSourceScanExec" => s"org.apache.spark.sql.rapids.Gpu" + execType
case "SortMergeJoinExec" => s"com.nvidia.spark.rapids.GpuShuffledHashJoinExec"
case "BroadcastExchangeExec" => s"org.apache.spark.sql.rapids.execution.Gpu" + execType
case "BroadcastHashJoinExec" => s"com.nvidia.spark.rapids.shims.spark300.Gpu" + execType
case "FileSourceScanExec" => s"org.apache.spark.sql.rapids.shims.spark300.Gpu" + execType
case "CartesianProductExec" => s"org.apache.spark.sql.rapids.Gpu" + execType
case "BroadcastNestedLoopJoinExec" =>
s"com.nvidia.spark.rapids.shims.spark300.Gpu" + execType
case "SortMergeJoinExec" | "ShuffledHashJoinExec" =>
s"com.nvidia.spark.rapids.shims.spark300.GpuShuffledHashJoinExec"
case "SortAggregateExec" => s"com.nvidia.spark.rapids.GpuHashAggregateExec"
case _ => s"com.nvidia.spark.rapids.Gpu" + execType
}

// TODO: Add error handling if Type is not present
val gpuTypes = stringToTypeTag(gpu)
val gpuTypes = classToTypeTag(Class.forName(gpu))

val sparkToGpuExecMap = Map(
"org.apache.spark.sql.catalyst.expressions.Expression" ->
Expand Down