Commit c4eee7e
- Replace println with Logging methods
- Add Nvtx range to the compilation
wjxiz1992 committed Aug 6, 2020
1 parent 1971bfc commit c4eee7e
Showing 7 changed files with 48 additions and 35 deletions.
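
The recurring change in this commit is the replacement of bare println calls with methods from Spark's org.apache.spark.internal.Logging trait. A minimal sketch of the pattern (the class and message below are illustrative, not taken from the commit): unlike println, the log methods take the message by name, so the string interpolation is skipped entirely when the level is disabled, and output is routed through the logging backend instead of stdout.

import org.apache.spark.internal.Logging

// Illustrative only: mixing in Logging provides logTrace/logDebug/logInfo/
// logWarning/logError. The message parameter is by-name (=> String), so the
// interpolation below costs nothing unless DEBUG logging is enabled.
case class ExampleRule() extends Logging {
  def rewrite(expr: String): String = {
    logDebug(s"rewriting expression: ${expr}")
    expr
  }
}
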
udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala (8 changes: 4 additions & 4 deletions)

@@ -59,11 +59,11 @@ case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging {
      Opcode.IF_ICMPGE | Opcode.IF_ICMPGT | Opcode.IF_ICMPLE |
      Opcode.IFLT | Opcode.IFLE | Opcode.IFGT | Opcode.IFGE |
      Opcode.IFEQ | Opcode.IFNE | Opcode.IFNULL | Opcode.IFNONNULL => {
-       println(s"[BB.propagateState] lastInstruction: ${lastInstruction.instructionStr}")
+       logTrace(s"[BB.propagateState] lastInstruction: ${lastInstruction.instructionStr}")

        // An if statement has both a false and a true successor
        val (0, falseSucc)::(1, trueSucc)::Nil = cfg.successor(this)
-       println(s"[BB.propagateState] falseSucc ${falseSucc} trueSucc ${trueSucc}")
+       logTrace(s"[BB.propagateState] falseSucc ${falseSucc} trueSucc ${trueSucc}")

        // cond is the entry condition into the condition block, and expr is the
        // actual condition for IF* (see Instruction.ifOp).
@@ -78,15 +78,15 @@ case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging {
        val falseState = state.copy(cond = simplify(And(cond, Not(expr.get))))
        val trueState = state.copy(cond = simplify(And(cond, expr.get)))

-       println(s"[BB.propagateState] States before: ${states}")
+       logDebug(s"[BB.propagateState] States before: ${states}")

        // Each successor may already have the state populated if it has
        // multiple predecessors.
        // Update the states by merging the new state with the existing state.
        val newStates = (states
          + (falseSucc -> falseState.merge(states.get(falseSucc)))
          + (trueSucc -> trueState.merge(states.get(trueSucc))))
-       println(s"[BB.propagateState] States after: ${newStates}")
+       logDebug(s"[BB.propagateState] States after: ${newStates}")
        newStates
      }
      case Opcode.TABLESWITCH | Opcode.LOOKUPSWITCH =>
udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala

@@ -16,17 +16,15 @@

package com.nvidia.spark.udf

-
-
import scala.annotation.tailrec

+import javassist.CtClass
+
import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

-import javassist.CtClass
-
/**
 * CatalystExpressionBuilder
 *
@@ -90,9 +88,9 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Logging {
    }}

    if (compiled == None) {
-     println(s"[CatalystExpressionBuilder] failed to compile")
+     logDebug(s"[CatalystExpressionBuilder] failed to compile")
    } else {
-     println(s"[CatalystExpressionBuilder] compiled expression: ${compiled.get.toString}")
+     logDebug(s"[CatalystExpressionBuilder] compiled expression: ${compiled.get.toString}")
    }

    compiled
@@ -134,11 +132,14 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Logging {
   * Pick the first block, and store the rest of the list in [[rest]].
   *
   * 1) Initially, [[worklist]] is [[CFG.head]] :: nil
-  * 2) As we recurse, [[worklist]] gets new [[BB]]s when all of its predecessors are visited.
+  * 2) As we recurse, [[worklist]] gets new [[BB]]s when all of its predecessors are
+  *    visited.
   * 3) The head [[BB]] ([[basicBlock]]) then goes through the compilation process:
   *    i) [[State]] is obtained (at the beginning, there's a seed [[State]] added in [[compile]])
-  *    ii) after each iteration, a new [[State]] is created for [[basicBlock]]. This is the first step, where we take the
-  *       javassist Opcode for each [[Instruction]] in the [[BB]]'s instruction table and turn it into [[State]]
+  *    ii) after each iteration, a new [[State]] is created for [[basicBlock]]. This is the
+  *        first step, where we take the
+  *        javassist Opcode for each [[Instruction]] in the [[BB]]'s instruction table and turn
+  *        it into [[State]]
   *        objects with: locals, stack, condition, and an evolving catalyst expression.
   *    iii) the state is then propagated:
@@ -238,7 +239,7 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Logging {
 * simplify a directly translated catalyst expression (from bytecode) into something simpler
 * that the remaining catalyst optimizations can handle.
 */
-object CatalystExpressionBuilder {
+object CatalystExpressionBuilder extends Logging {
  /** simplify: given a raw converted catalyst expression, attempt to match patterns to simplify
    * before handing it over to catalyst optimizers (the LogicalPlan does this later).
    *
@@ -416,7 +417,7 @@ object CatalystExpressionBuilder {
          simplifyExpr(Cast(f, BooleanType, tz))))
      case _ => expr
    }
-   println(s"[CatalystExpressionBuilder] simplify: ${expr} ==> ${res}")
+   logDebug(s"[CatalystExpressionBuilder] simplify: ${expr} ==> ${res}")
    res
  }
  val simplifiedExpr = simplifyExpr(expr)
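
The scaladoc in this file describes a worklist traversal: a basic block is compiled only once all of its predecessors have been visited, and its resulting State is merged into each successor. As a rough, self-contained sketch of that shape (every name and the toy State type here are invented, none of this is the compiler's actual code):

import scala.annotation.tailrec

// Hypothetical sketch of the worklist described above: pop a block, fold its
// instructions into an output state (stubbed here), merge that state into each
// successor, and enqueue a successor once every one of its predecessors is done.
object WorklistSketch {
  case class BB(id: Int, preds: List[Int], succs: List[Int])
  type State = Set[String] // stand-in for locals, stack, condition, expression

  def run(cfg: Map[Int, BB], seed: State): Map[Int, State] = {
    @tailrec
    def loop(worklist: List[Int], states: Map[Int, State], done: Set[Int]): Map[Int, State] =
      worklist match {
        case Nil => states
        case b :: rest =>
          val out = states(b) + s"b$b" // stand-in for makeState over the instruction table
          val merged = cfg(b).succs.foldLeft(states) { (st, s) =>
            st.updated(s, st.getOrElse(s, Set.empty[String]) ++ out) // merge, as in propagateState
          }
          val finished = done + b
          val ready = cfg(b).succs.filter(s => cfg(s).preds.forall(finished.contains))
          loop(rest ++ ready, merged, finished)
      }
    loop(List(0), Map(0 -> seed), Set.empty)
  }
}
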
udf-compiler/src/main/scala/com/nvidia/spark/udf/GpuScalaUDFLogical.scala

@@ -17,15 +17,17 @@
package com.nvidia.spark.udf

import com.nvidia.spark.rapids.RapidsConf

import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.DataType

-case class GpuScalaUDFLogical(udf: ScalaUDF) extends Expression {
+case class GpuScalaUDFLogical(udf: ScalaUDF) extends Expression with Logging {
  override def nullable: Boolean = udf.nullable

  override def eval(input: InternalRow): Any = {
  }

@@ -37,7 +39,6 @@ case class GpuScalaUDFLogical(udf: ScalaUDF) extends Expression {

  def compile(isTestEnabled: Boolean): Expression = {
    // call the compiler
-   // just an example
    try {
      val expr = CatalystExpressionBuilder(udf.function).compile(udf.children)
      if (expr.isDefined) {
@@ -47,7 +48,7 @@
      }
    } catch {
      case e: SparkException =>
-       System.err.println("UDF compilation failure: " + e)
+       logError("UDF compilation failure: " + e)
        if (isTestEnabled) {
          throw e
        }
udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala (26 changes: 15 additions & 11 deletions)

@@ -16,14 +16,18 @@

package com.nvidia.spark.udf

-import javassist.bytecode.{CodeIterator, Opcode}
-import CatalystExpressionBuilder.simplify
import java.nio.charset.Charset

+import javassist.bytecode.{CodeIterator, Opcode}
+
import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

+import CatalystExpressionBuilder.simplify
+

private object Repr {
  abstract class CompilerInternal(name: String) extends Expression {
@@ -71,11 +75,11 @@ case class StringBuilder() extends CompilerInternal("java.lang.StringBuilder") {
}

/**
- *
- * @param opcode
- * @param operand
- */
-case class Instruction(opcode: Int, operand: Int, instructionStr: String) {
+  *
+  * @param opcode
+  * @param operand
+  */
+case class Instruction(opcode: Int, operand: Int, instructionStr: String) extends Logging {
  def makeState(lambdaReflection: LambdaReflection, basicBlock: BB, state: State): State = {
    val st = opcode match {
      case Opcode.ALOAD_0 | Opcode.DLOAD_0 | Opcode.FLOAD_0 |
@@ -175,7 +179,7 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) {
        })
      case _ => throw new SparkException("Unsupported instruction: " + instructionStr)
    }
-   println(s"[Instruction] ${instructionStr} got new state: ${st} from state: ${state}")
+   logDebug(s"[Instruction] ${instructionStr} got new state: ${st} from state: ${state}")
    st
  }

@@ -515,8 +519,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) {
}

/**
- * Ultimately, every opcode will have to be covered here.
- */
+  * Ultimately, every opcode will have to be covered here.
+  */
object Instruction {
  def apply(codeIterator: CodeIterator, offset: Int, instructionStr: String): Instruction = {
    val opcode: Int = codeIterator.byteAt(offset)
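
Instruction.apply consumes a javassist CodeIterator one offset at a time. For context, this is roughly how such an iterator is obtained and walked; a standalone sketch, with an arbitrary target method chosen for illustration:

import javassist.ClassPool
import javassist.bytecode.Mnemonic

// Sketch: iterate over the bytecode of an arbitrary method and print each
// opcode's mnemonic. This is the same raw stream Instruction.apply decodes.
object WalkBytecode {
  def main(args: Array[String]): Unit = {
    val pool = ClassPool.getDefault
    val method = pool.get("java.lang.String").getDeclaredMethod("isEmpty")
    val it = method.getMethodInfo.getCodeAttribute.iterator()
    while (it.hasNext) {
      val offset = it.next()         // index of the next instruction
      val opcode = it.byteAt(offset) // same call Instruction.apply makes
      println(s"$offset: ${Mnemonic.OPCODE(opcode)}")
    }
  }
}
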
udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala

@@ -119,9 +119,7 @@ object LambdaReflection {
    val serializedLambda = writeReplace.invoke(function)
      .asInstanceOf[SerializedLambda]

-   // ClassPool: javassist TODO: watch for this object getting huge
-   // http://www.javassist.org/html/javassist/ClassPool.html
-   val classPool = ClassPool.getDefault
+   val classPool = new ClassPool(true)
    LambdaReflection(classPool, serializedLambda)
  }
}
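
The move away from ClassPool.getDefault addresses the removed TODO: the default pool is a JVM-wide singleton whose CtClass cache lives for the whole process, while a fresh pool scopes that cache to its owner. A sketch of the distinction (the object and value names are illustrative):

import javassist.ClassPool

object PoolScopes {
  // JVM-wide singleton: its CtClass cache accumulates for the life of the
  // process, which is what the removed "getting huge" TODO warned about.
  val shared: ClassPool = ClassPool.getDefault

  // Private pool: `true` appends the system search path so JDK and classpath
  // classes still resolve, and the cache can be garbage-collected with its owner.
  val scoped: ClassPool = new ClassPool(true)
}
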
udf-compiler/src/main/scala/com/nvidia/spark/udf/Plugin.scala (12 changes: 10 additions & 2 deletions)

@@ -16,6 +16,7 @@

package com.nvidia.spark.udf

+import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids.RapidsConf

import org.apache.spark.internal.Logging
@@ -34,7 +35,15 @@ class Plugin extends Function1[SparkSessionExtensions, Unit] with Logging {

case class LogicalPlanRules() extends Rule[LogicalPlan] with Logging {
  def replacePartialFunc(plan: LogicalPlan): PartialFunction[Expression, Expression] = {
-   case d: Expression => attemptToReplaceExpression(plan, d)
+   case d: Expression =>
+     {
+       val nvtx = new NvtxRange("replace UDF", NvtxColor.BLUE)
+       try {
+         attemptToReplaceExpression(plan, d)
+       } finally {
+         nvtx.close()
+       }
+     }
  }

  def attemptToReplaceExpression(plan: LogicalPlan, exp: Expression): Expression = {
@@ -61,7 +70,6 @@ case class LogicalPlanRules() extends Rule[LogicalPlan] with Logging {
      }
    } catch {
      case npe: NullPointerException => {
-       println("npe... never mind then")
        exp
      }
    }
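
NvtxRange marks a span on the Nsight Systems timeline, and the try/finally guarantees the range closes even if the replacement throws. The same pattern could be factored into a small loan helper; a hypothetical sketch, not part of the plugin:

import ai.rapids.cudf.{NvtxColor, NvtxRange}

// Hypothetical helper wrapping the try/finally pattern from the diff above.
object Nvtx {
  def range[T](name: String, color: NvtxColor)(body: => T): T = {
    val nvtx = new NvtxRange(name, color)
    try {
      body
    } finally {
      nvtx.close()
    }
  }
}

// e.g. Nvtx.range("replace UDF", NvtxColor.BLUE) { attemptToReplaceExpression(plan, d) }
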
OpcodeSuite.scala (udf-compiler tests)

@@ -30,6 +30,7 @@ class OpcodeSuite extends FunSuite {

  val conf: SparkConf = new SparkConf()
    .set("spark.sql.extensions", "com.nvidia.spark.udf.Plugin")
+   .set("spark.rapids.sql.udfCompiler.enabled", "true")
    .set(RapidsConf.EXPLAIN.key, "true")

  val spark: SparkSession =
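
The new test setting mirrors what a user application would need: the explicit enable flag suggests the UDF compiler is opt-in, so both the extension and the flag must be set. A sketch of a session configured the same way (app name and master are illustrative; the config keys are taken from the diff):

import org.apache.spark.sql.SparkSession

object EnableUdfCompiler {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("udf-compiler-example") // illustrative
      .master("local[*]")              // illustrative
      .config("spark.sql.extensions", "com.nvidia.spark.udf.Plugin")
      .config("spark.rapids.sql.udfCompiler.enabled", "true")
      .getOrCreate()
    try {
      // Scala UDFs defined from here on are candidates for translation
      // into Catalyst expressions by the udf-compiler rule.
    } finally {
      spark.stop()
    }
  }
}
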
