diff --git a/docs/compatibility.md b/docs/compatibility.md index ff2889ec9ab..d7b2f5e2182 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -277,5 +277,73 @@ However, Spark may produce different results for a compiled udf and the non-comp When translating UDFs to Catalyst expressions, the supported UDF functions are limited: -| Operand type | Operation | -| ------------------------------------------------------------------- | ------------------| +| Operand type | Operation | +| -------------------------| ---------------------------------------------------------| +| Arithmetic Unary | +x | +| | -x | +| Arithmetic Binary | lhs + rhs | +| | lhs - rhs | +| | lhs * rhs | +| | lhs / rhs | +| | lhs % rhs | +| Logical | lhs && rhs | +| | lhs || rhs | +| | !x | +| Equality and Relational | lhs == rhs | +| | lhs < rhs | +| | lhs <= rhs | +| | lhs > rhs | +| | lhs >= rhs | +| Bitwise | lhs & rhs | +| | lhs | rhs | +| | lhs ^ rhs | +| | ~x | +| | lhs << rhs | +| | lhs >> rhs | +| | lhs >>> rhs | +| Conditional | if | +| | case | +| Math | abs(x) | +| | cos(x) | +| | acos(x) | +| | asin(x) | +| | tan(x) | +| | atan(x) | +| | tanh(x) | +| | cosh(x) | +| | ceil(x) | +| | floor(x) | +| | exp(x) | +| | log(x) | +| | log10(x) | +| | sqrt(x) | +| Type Cast | * | +| String | lhs + rhs | +| | lhs.equalsIgnoreCase(String rhs) | +| | x.toUpperCase() | +| | x.trim() | +| | x.substring(int begin) | +| | x.substring(int begin, int end) | +| | x.replace(char oldChar, char newChar) | +| | x.replace(CharSequence target, CharSequence replacement) | +| | x.startsWith(String prefix) | +| | lhs.equals(Object rhs) | +| | x.toLowerCase() | +| | x.length() | +| | x.endsWith(String suffix) | +| | lhs.concat(String rhs) | +| | x.isEmpty() | +| | String.valueOf(boolean b) | +| | String.valueOf(char c) | +| | String.valueOf(double d) | +| | String.valueOf(float f) | +| | String.valueOf(int i) | +| | String.valueOf(long l) | +| | x.contains(CharSequence s) | +| | x.indexOf(String str) | +| | x.indexOf(String str, int fromIndex) | +| |x.replaceAll(String regex, String replacement) | +| |x.split(String regex) | +| |x.split(String regex, int limit) | +| |x.getBytes() | +| |x.getBytes(String charsetName) | diff --git a/docs/get-started/getting-started-with-rapids-accelerator-on-databricks.md b/docs/get-started/getting-started-with-rapids-accelerator-on-databricks.md index be895e6881d..b3b17477f9d 100644 --- a/docs/get-started/getting-started-with-rapids-accelerator-on-databricks.md +++ b/docs/get-started/getting-started-with-rapids-accelerator-on-databricks.md @@ -1,88 +1,88 @@ ---- -layout: page -title: Databricks -nav_order: 3 -parent: Getting-Started ---- - -# Getting started with RAPIDS Accelerator on Databricks -This guide will run through how to set up the RAPIDS Accelerator for Apache Spark 3.0 on Databricks. At the end of this guide, the reader will be able to run a sample Apache Spark application that runs on NVIDIA GPUs on Databricks. - -## Prerequisites - * Apache Spark 3.0 running in DataBricks Runtime 7.0 ML with GPU - * AWS: 7.0 ML (includes Apache Spark 3.0.0, GPU, Scala 2.12) - * Azure: 7.0 ML (GPU, Scala 2.12, Spark 3.0.0) - -The number of GPUs per node dictates the number of Spark executors that can run in that node. - -## Start a Databricks Cluster -Create a Databricks cluster by going to Clusters, then clicking “+ Create Cluster”. Ensure the cluster meets the prerequisites above by configuring it as follows: -1. On AWS, make sure to use 7.0 ML (GPU, Scala 2.12, Spark 3.0.0), or for Azure, choose 7.0 ML (GPU, Scala 2.12, Spark 3.0.0). -2. Under Autopilot Options, disable auto scaling. -3. Choose the number of workers that matches the number of GPUs you want to use. -4. Select a worker type. On AWS, use nodes with 1 GPU each such as `p3.xlarge` or `g4dn.xlarge`. p2 nodes do not meet the architecture requirements for the Spark worker (although they can be used for the driver node). For Azure, choose GPU nodes such as Standard_NC6s_v3. -5. Select the driver type. Generally this can be set to be the same as the worker. -6. Start the cluster - -## Advanced Cluster Configuration - -We will need to create an initialization script for the cluster that installs the RAPIDS jars to the cluster. - -1. To create the initialization script, import the initialization script notebook from the repo [generate-init-script.ipynb](../demo/Databricks/generate-init-script.ipynb) to your workspace. See [Managing Notebooks](https://docs.databricks.com/notebooks/notebooks-manage.html#id2) on how to import a notebook, then open the notebook. -2. Once you are in the notebook, click the “Run All” button. -3. Ensure that the newly created init.sh script is present in the output from cell 2 and that the contents of the script are correct. -4. Go back and edit your cluster to configure it to use the init script. To do this, click the “Clusters” button on the left panel, then select your cluster. -5. Click the “Edit” button, then navigate down to the “Advanced Options” section. Select the “Init Scripts” tab in the advanced options section, and paste the initialization script: `dbfs:/databricks/init_scripts/init.sh`, then click “Add”. - - ![Init Script](../img/initscript.png) - -6. Now select the “Spark” tab, and paste the following config options into the Spark Config section. Change the config values based on the workers you choose. See Apache Spark [configuration](https://spark.apache.org/docs/latest/configuration.html) and RAPIDS Accelerator for Apache Spark [descriptions](../configs) for each config. - - The [`spark.task.resource.gpu.amount`](https://spark.apache.org/docs/latest/configuration.html#scheduling) configuration is defaulted to 1 by Databricks. That means that only 1 task can run on an executor with 1 GPU, which is limiting, especially on the reads and writes from Parquet. Set this to 1/(number of cores per executor) which will allow multiple tasks to run in parallel just like the CPU side. Having the value smaller is fine as well. - - ```bash - spark.plugins com.nvidia.spark.SQLPlugin - spark.sql.parquet.filterPushdown false - spark.rapids.sql.incompatibleOps.enabled true - spark.rapids.memory.pinnedPool.size 2G - spark.task.resource.gpu.amount 0.1 - spark.rapids.sql.concurrentGpuTasks 2 - spark.locality.wait 0s - spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version 2 - spark.executor.extraJavaOptions "-Dai.rapids.cudf.prefer-pinned=true" - ``` - - ![Spark Config](../img/sparkconfig.png) - -7. Once you’ve added the Spark config, click “Confirm and Restart”. -8. Once the cluster comes back up, it is now enabled for GPU-accelerated Spark with RAPIDS and cuDF. - -## Import the GPU Mortgage Example Notebook -Import the example [notebook](../demo/gpu-mortgage_accelerated.ipynb) from the repo into your workspace, then open the notebook. -Modify the first cell to point to your workspace, and download a larger dataset if needed. You can find the links to the datasets at [docs.rapids.ai](https://docs.rapids.ai/datasets/mortgage-data). - -```bash -%sh - -wget http://rapidsai-data.s3-website.us-east-2.amazonaws.com/notebook-mortgage-data/mortgage_2000.tgz -P /Users// - -mkdir -p /dbfs/FileStore/tables/mortgage -mkdir -p /dbfs/FileStore/tables/mortgage_parquet_gpu/perf -mkdir /dbfs/FileStore/tables/mortgage_parquet_gpu/acq -mkdir /dbfs/FileStore/tables/mortgage_parquet_gpu/output - -tar xfvz /Users//mortgage_2000.tgz --directory /dbfs/FileStore/tables/mortgage -``` - -In Cell 3, update the data paths if necessary. The example notebook merges the columns and prepares the data for XGoost training. The temp and final output results are written back to the dbfs. -```bash -orig_perf_path='dbfs:///FileStore/tables/mortgage/perf/*' -orig_acq_path='dbfs:///FileStore/tables/mortgage/acq/*' -tmp_perf_path='dbfs:///FileStore/tables/mortgage_parquet_gpu/perf/' -tmp_acq_path='dbfs:///FileStore/tables/mortgage_parquet_gpu/acq/' -output_path='dbfs:///FileStore/tables/mortgage_parquet_gpu/output/' -``` -Run the notebook by clicking “Run All”. - -## Hints -Spark logs in Databricks are removed upon cluster shutdown. It is possible to save logs in a cloud storage location using Databricks [cluster log delivery](https://docs.databricks.com/clusters/configure.html#cluster-log-delivery-1). Enable this option before starting the cluster to capture the logs. +--- +layout: page +title: Databricks +nav_order: 3 +parent: Getting-Started +--- + +# Getting started with RAPIDS Accelerator on Databricks +This guide will run through how to set up the RAPIDS Accelerator for Apache Spark 3.0 on Databricks. At the end of this guide, the reader will be able to run a sample Apache Spark application that runs on NVIDIA GPUs on Databricks. + +## Prerequisites + * Apache Spark 3.0 running in DataBricks Runtime 7.0 ML with GPU + * AWS: 7.0 ML (includes Apache Spark 3.0.0, GPU, Scala 2.12) + * Azure: 7.0 ML (GPU, Scala 2.12, Spark 3.0.0) + +The number of GPUs per node dictates the number of Spark executors that can run in that node. + +## Start a Databricks Cluster +Create a Databricks cluster by going to Clusters, then clicking “+ Create Cluster”. Ensure the cluster meets the prerequisites above by configuring it as follows: +1. On AWS, make sure to use 7.0 ML (GPU, Scala 2.12, Spark 3.0.0), or for Azure, choose 7.0 ML (GPU, Scala 2.12, Spark 3.0.0). +2. Under Autopilot Options, disable auto scaling. +3. Choose the number of workers that matches the number of GPUs you want to use. +4. Select a worker type. On AWS, use nodes with 1 GPU each such as `p3.xlarge` or `g4dn.xlarge`. p2 nodes do not meet the architecture requirements for the Spark worker (although they can be used for the driver node). For Azure, choose GPU nodes such as Standard_NC6s_v3. +5. Select the driver type. Generally this can be set to be the same as the worker. +6. Start the cluster + +## Advanced Cluster Configuration + +We will need to create an initialization script for the cluster that installs the RAPIDS jars to the cluster. + +1. To create the initialization script, import the initialization script notebook from the repo [generate-init-script.ipynb](../demo/Databricks/generate-init-script.ipynb) to your workspace. See [Managing Notebooks](https://docs.databricks.com/notebooks/notebooks-manage.html#id2) on how to import a notebook, then open the notebook. +2. Once you are in the notebook, click the “Run All” button. +3. Ensure that the newly created init.sh script is present in the output from cell 2 and that the contents of the script are correct. +4. Go back and edit your cluster to configure it to use the init script. To do this, click the “Clusters” button on the left panel, then select your cluster. +5. Click the “Edit” button, then navigate down to the “Advanced Options” section. Select the “Init Scripts” tab in the advanced options section, and paste the initialization script: `dbfs:/databricks/init_scripts/init.sh`, then click “Add”. + + ![Init Script](../img/initscript.png) + +6. Now select the “Spark” tab, and paste the following config options into the Spark Config section. Change the config values based on the workers you choose. See Apache Spark [configuration](https://spark.apache.org/docs/latest/configuration.html) and RAPIDS Accelerator for Apache Spark [descriptions](../configs) for each config. + + The [`spark.task.resource.gpu.amount`](https://spark.apache.org/docs/latest/configuration.html#scheduling) configuration is defaulted to 1 by Databricks. That means that only 1 task can run on an executor with 1 GPU, which is limiting, especially on the reads and writes from Parquet. Set this to 1/(number of cores per executor) which will allow multiple tasks to run in parallel just like the CPU side. Having the value smaller is fine as well. + + ```bash + spark.plugins com.nvidia.spark.SQLPlugin + spark.sql.parquet.filterPushdown false + spark.rapids.sql.incompatibleOps.enabled true + spark.rapids.memory.pinnedPool.size 2G + spark.task.resource.gpu.amount 0.1 + spark.rapids.sql.concurrentGpuTasks 2 + spark.locality.wait 0s + spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version 2 + spark.executor.extraJavaOptions "-Dai.rapids.cudf.prefer-pinned=true" + ``` + + ![Spark Config](../img/sparkconfig.png) + +7. Once you’ve added the Spark config, click “Confirm and Restart”. +8. Once the cluster comes back up, it is now enabled for GPU-accelerated Spark with RAPIDS and cuDF. + +## Import the GPU Mortgage Example Notebook +Import the example [notebook](../demo/gpu-mortgage_accelerated.ipynb) from the repo into your workspace, then open the notebook. +Modify the first cell to point to your workspace, and download a larger dataset if needed. You can find the links to the datasets at [docs.rapids.ai](https://docs.rapids.ai/datasets/mortgage-data). + +```bash +%sh + +wget http://rapidsai-data.s3-website.us-east-2.amazonaws.com/notebook-mortgage-data/mortgage_2000.tgz -P /Users// + +mkdir -p /dbfs/FileStore/tables/mortgage +mkdir -p /dbfs/FileStore/tables/mortgage_parquet_gpu/perf +mkdir /dbfs/FileStore/tables/mortgage_parquet_gpu/acq +mkdir /dbfs/FileStore/tables/mortgage_parquet_gpu/output + +tar xfvz /Users//mortgage_2000.tgz --directory /dbfs/FileStore/tables/mortgage +``` + +In Cell 3, update the data paths if necessary. The example notebook merges the columns and prepares the data for XGoost training. The temp and final output results are written back to the dbfs. +```bash +orig_perf_path='dbfs:///FileStore/tables/mortgage/perf/*' +orig_acq_path='dbfs:///FileStore/tables/mortgage/acq/*' +tmp_perf_path='dbfs:///FileStore/tables/mortgage_parquet_gpu/perf/' +tmp_acq_path='dbfs:///FileStore/tables/mortgage_parquet_gpu/acq/' +output_path='dbfs:///FileStore/tables/mortgage_parquet_gpu/output/' +``` +Run the notebook by clicking “Run All”. + +## Hints +Spark logs in Databricks are removed upon cluster shutdown. It is possible to save logs in a cloud storage location using Databricks [cluster log delivery](https://docs.databricks.com/clusters/configure.html#cluster-log-delivery-1). Enable this option before starting the cluster to capture the logs. diff --git a/udf-compiler/README.md b/udf-compiler/README.md index 692c1043fb0..957b8547da6 100644 --- a/udf-compiler/README.md +++ b/udf-compiler/README.md @@ -20,6 +20,6 @@ export SPARK_HOME=[your spark distribution directory] export JARS=[path to cudf 0.15 jar] $SPARK_HOME/bin/spark-shell \ ---jars $JARS/cudf-0.15-SNAPSHOT-cuda10-1.jar,udf-compiler/target/rapids-4-spark-udf-0.2.0-SNAPSHOT.jar,sql-plugin/target/rapids-4-spark-sql_2.12-0.2.0-SNAPSHOT.jar \ +--jars $JARS/cudf-0.15-SNAPSHOT-cuda10-1.jar,udf-compiler/target/rapids-4-spark-udf_2.12-0.2.0-SNAPSHOT.jar,sql-plugin/target/rapids-4-spark-sql_2.12-0.2.0-SNAPSHOT.jar \ --conf spark.sql.extensions="com.nvidia.spark.SQLPlugin,com.nvidia.spark.udf.Plugin" ``` diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala new file mode 100644 index 00000000000..61f46f92cea --- /dev/null +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala @@ -0,0 +1,329 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.udf + +import CatalystExpressionBuilder.simplify +import javassist.bytecode.{CodeIterator, ConstPool, InstructionPrinter, Opcode} +import scala.annotation.tailrec +import scala.collection.immutable.{HashMap, SortedMap, SortedSet} + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Control Flow Graph (CFG) + * + * This file defines the basic blocks (BB), and a class that can generate a CFG + * by leveraging [[LambdaReflection]]. + */ + +/** + * A Basic Block (BB) is a set of instructions defined by [[instructionTable]], + * where each entry in this table is a mapping of offset to [[Instruction]]. + * + * The case class also provides some helpers, most importantly the [[propagateState]] helper + * which generates a map of [[BB]] to [[State]]. + * + * @param instructionTable + */ +case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging { + def offset: Int = instructionTable.head._1 + + def last: (Int, Instruction) = instructionTable.last + + def lastOffset: Int = last._1 + + def lastInstruction: Instruction = last._2 + + def propagateState(cfg: CFG, states: Map[BB, State]): Map[BB, State] = { + val state@State(_, _, cond, expr) = states(this) + logDebug(s"[BB.propagateState] propagating condition: ${cond} from state ${state} " + + s"onto states: ${states}") + lastInstruction.opcode match { + case Opcode.IF_ICMPEQ | Opcode.IF_ICMPNE | Opcode.IF_ICMPLT | + Opcode.IF_ICMPGE | Opcode.IF_ICMPGT | Opcode.IF_ICMPLE | + Opcode.IFLT | Opcode.IFLE | Opcode.IFGT | Opcode.IFGE | + Opcode.IFEQ | Opcode.IFNE | Opcode.IFNULL | Opcode.IFNONNULL => { + logTrace(s"[BB.propagateState] lastInstruction: ${lastInstruction.instructionStr}") + + // An if statement has both a false and a true successor + val (0, falseSucc) :: (1, trueSucc) :: Nil = cfg.successor(this) + logTrace(s"[BB.propagateState] falseSucc ${falseSucc} trueSuccc ${trueSucc}") + + // cond is the entry condition into the condition block, and expr is the + // actual condition for IF* (see Instruction.ifOp). + // The entry conditions into the false branch are (cond && !expr) and + // (cond && expr) respectively. + // + // For each successor, create a copy of the current state, and modify it + // with the entry condition for the successor. This ensures that the + // state is propagated to the successors with the correct entry + // conditions. + // + val falseState = state.copy(cond = simplify(And(cond, Not(expr.get)))) + val trueState = state.copy(cond = simplify(And(cond, expr.get))) + + logDebug(s"[BB.propagateState] States before: ${states}") + + // Each successor may already have the state populated if it has + // multiple predecessors. + // Update the states by merging the new state with the existing state. + val newStates = (states + + (falseSucc -> falseState.merge(states.get(falseSucc))) + + (trueSucc -> trueState.merge(states.get(trueSucc)))) + logDebug(s"[BB.propagateState] States after: ${newStates}") + newStates + } + case Opcode.TABLESWITCH | Opcode.LOOKUPSWITCH => + val table = cfg.successor(this).init + val (-1, defaultSucc) = cfg.successor(this).last + // Update the entry conditions of non-default successors based on the + // match keys, and combine them to create the entry condition of the + // default successor. + val (defaultCondition, newStates) = ( + table.foldLeft[(Expression, Map[BB, State])]((Literal.TrueLiteral, states)) { + case ((cond: Expression, currentStates: Map[BB, State]), + (matchKey: Int, succ: BB)) => + val newState = state.copy(cond = simplify(EqualTo(expr.get, Literal(matchKey)))) + (And(Not(EqualTo(expr.get, Literal(matchKey))), cond), + currentStates + (succ -> newState.merge(currentStates.get(succ)))) + } + ) + // Update the entry condition of the default successor. + val defaultState = state.copy(cond = simplify(defaultCondition)) + newStates + (defaultSucc -> defaultState.merge(newStates.get(defaultSucc))) + case Opcode.IRETURN | Opcode.LRETURN | Opcode.FRETURN | Opcode.DRETURN | + Opcode.ARETURN | Opcode.RETURN => states + case _ => + val (0, successor) :: Nil = cfg.successor(this) + // The condition, stack and locals from the current BB state need to be + // propagated to its successor. + states + (successor -> state.merge(states.get(successor))) + } + } +} + +/** + * The Control Flow Graph object. + * + * @param basicBlocks : the basic blocks for this CFG + * @param predecessor : given a [[BB]] this maps the [[BB]]s to its predecessors + * @param successor : given a [[BB]] this maps the [[BB]]s to its successors. + * Each element in the succssor list also has an Int value + * which is used as a case for tableswitch and lookupswitch. + */ +case class CFG(basicBlocks: List[BB], + predecessor: Map[BB, List[BB]], + successor: Map[BB, List[(Int, BB)]]) + +/** + * Companion object to generate a [[CFG]] instance given a [[LambdaReflection]] + */ +object CFG { + /** + * Iterate through the code to find out the basic blocks + */ + def apply(lambdaReflection: LambdaReflection): CFG = { + val codeIterator = lambdaReflection.codeIterator + + // labels: targets of branching instructions (offset) + // edges: connection between branch instruction offset, and target offsets (successors) + // if ifeq then there would be a true and a false successor + // if return there would be no successors (likely) + // goto has 1 successors + codeIterator.begin() + val (labels, edges) = collectLabelsAndEdges(codeIterator, lambdaReflection.constPool) + + codeIterator.begin() // rewind + val instructionTable = createInstructionTable(codeIterator, lambdaReflection.constPool) + + val (basicBlocks, offsetToBB) = createBasicBlocks(labels, instructionTable) + + val (predecessor, successor) = connectBasicBlocks(basicBlocks, offsetToBB, edges) + + CFG(basicBlocks, predecessor, successor) + } + + @tailrec + private def collectLabelsAndEdges(codeIterator: CodeIterator, + constPool: ConstPool, + labels: SortedSet[Int] = SortedSet(), + edges: SortedMap[Int, List[(Int, Int)]] = SortedMap()) + : (SortedSet[Int], SortedMap[Int, List[(Int, Int)]]) = { + if (codeIterator.hasNext) { + val offset: Int = codeIterator.next + val nextOffset: Int = codeIterator.lookAhead + val opcode: Int = codeIterator.byteAt(offset) + // here we are looking for branching instructions + opcode match { + case Opcode.IF_ICMPEQ | Opcode.IF_ICMPNE | Opcode.IF_ICMPLT | + Opcode.IF_ICMPGE | Opcode.IF_ICMPGT | Opcode.IF_ICMPLE | + Opcode.IFEQ | Opcode.IFNE | Opcode.IFLT | Opcode.IFGE | + Opcode.IFGT | Opcode.IFLE | Opcode.IFNULL | Opcode.IFNONNULL => + // an if statement has two other offsets, false and true branches. + + // the false offset is the next offset, per the definition of if + val falseOffset = nextOffset + + // in jvm, the if ops are followed by two bytes, which are to be + // used together (s16bitAt does this for us) only for the success case of the if + val trueOffset = offset + codeIterator.s16bitAt(offset + 1) + + // keep iterating, having added the false and true offsets to the labels, + // and having added the edges (if offset -> List(false offset, true offset)) + collectLabelsAndEdges( + codeIterator, constPool, + labels + falseOffset + trueOffset, + edges + (offset -> List((0, falseOffset), (1, trueOffset)))) + case Opcode.TABLESWITCH => + val defaultOffset = (offset + 4) / 4 * 4 + val default = (-1, offset + codeIterator.s32bitAt(defaultOffset)) + val lowOffset = defaultOffset + 4 + val low = codeIterator.s32bitAt(lowOffset) + val highOffset = lowOffset + 4 + val high = codeIterator.s32bitAt(highOffset) + val tableOffset = highOffset + 4 + val table = List.tabulate(high - low + 1) { i => + (low + i, offset + codeIterator.s32bitAt(tableOffset + i * 4)) + } :+ default + collectLabelsAndEdges( + codeIterator, constPool, + labels ++ table.unzip._2, + edges + (offset -> table)) + case Opcode.LOOKUPSWITCH => + val defaultOffset = (offset + 4) / 4 * 4 + val default = (-1, offset + codeIterator.s32bitAt(defaultOffset)) + val npairsOffset = defaultOffset + 4 + val npairs = codeIterator.s32bitAt(npairsOffset) + val tableOffset = npairsOffset + 4 + val table = List.tabulate(npairs) { i => + (codeIterator.s32bitAt(tableOffset + i * 8), + offset + codeIterator.s32bitAt(tableOffset + i * 8 + 4)) + } :+ default + collectLabelsAndEdges( + codeIterator, constPool, + labels ++ table.unzip._2, + edges + (offset -> table)) + case Opcode.GOTO | Opcode.GOTO_W => + // goto statements have a single address target, we must go there + val getOffset = if (opcode == Opcode.GOTO) { + codeIterator.s16bitAt(_) + } else { + codeIterator.s32bitAt(_) + } + val labelOffset = offset + getOffset(offset + 1) + collectLabelsAndEdges( + codeIterator, constPool, + labels + labelOffset, + edges + (offset -> List((0, labelOffset)))) + case Opcode.IF_ACMPEQ | Opcode.IF_ACMPNE | + Opcode.JSR | Opcode.JSR_W | Opcode.RET => + val instructionStr = InstructionPrinter.instructionString(codeIterator, offset, constPool) + throw new SparkException("Unsupported instruction: " + instructionStr) + case _ => collectLabelsAndEdges(codeIterator, constPool, labels, edges) + } + } else { + // base case + (labels, edges) + } + } + + @tailrec + private def createInstructionTable(codeIterator: CodeIterator, constPool: ConstPool, + instructionTable: SortedMap[Int, Instruction] = SortedMap()) + : SortedMap[Int, Instruction] = { + if (codeIterator.hasNext) { + val offset = codeIterator.next + val instructionStr = InstructionPrinter.instructionString(codeIterator, offset, constPool) + val instruction = Instruction(codeIterator, offset, instructionStr) + createInstructionTable(codeIterator, constPool, + instructionTable + (offset -> instruction)) + } else { + instructionTable + } + } + + @tailrec + private def createBasicBlocks(labels: SortedSet[Int], + instructionTable: SortedMap[Int, Instruction], + basicBlocks: List[BB] = List(), + offsetToBB: Map[Int, BB] = HashMap()): (List[BB], Map[Int, BB]) = { + if (labels.isEmpty) { + val instructions = instructionTable + val bb = BB(instructions) + ((bb +: basicBlocks).reverse, + instructions.foldLeft(offsetToBB) { case (offsetToBB, (offset, _)) => + offsetToBB + (offset -> bb) + }) + } else { + // get instuctions prior to the first label (branch) we are looking at + val (instructions, instructionsForOtherBBs) = instructionTable.span(_._1 < labels.head) + + // BB is a node in the CFG, BB -> BB connects via branch + val bb = BB(instructions) // put the instructions that belong together into a BB + + // create more BB's with the rest of the instructions post branch + createBasicBlocks( + labels.tail, + instructionsForOtherBBs, + // With immutable linked list, prepend is faster than append. + // basicBlocks is an immutable linked list. + bb +: basicBlocks, + instructions.foldLeft(offsetToBB) { case (offsetToBB, (offset, _)) => + offsetToBB + (offset -> bb) + }) + } + } + + @tailrec + private def connectBasicBlocks(basicBlocks: List[BB], + offsetToBB: Map[Int, BB], + edges: SortedMap[Int, List[(Int, Int)]], + predecessor: Map[BB, List[BB]] = Map().withDefaultValue(Nil), + successor: Map[BB, List[(Int, BB)]] = Map().withDefaultValue(Nil)) + : (Map[BB, List[BB]], Map[BB, List[(Int, BB)]]) = { + if (basicBlocks.isEmpty) { + (predecessor, successor) + } else { + // Connect the first basic block in basicBlocks (src) with its predecssors + // and successors. + val src :: rest = basicBlocks + // Get the destination basic blocks (dst) of the edges that connect from + // src. + val dst = edges.getOrElse(src.lastOffset, + if (rest.isEmpty) { + List() + } else { + List((0, rest.head.offset)) + }).map { case (k, v) => (k, offsetToBB(v)) } + // Recursively call connectBasicBlocks with the rest of basicBlocks. + connectBasicBlocks( + rest, + offsetToBB, + edges, + //For each basic block, l, in dst, update predecessor map by prepending + //src to predecessor(l). + dst.foldLeft(predecessor) { case (p: Map[BB, List[BB]], (_, l)) => { + p + (l -> (src :: p(l))) + } + }, + // Add src -> dst to successor map. + successor + (src -> dst)) + } + } +} diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala new file mode 100644 index 00000000000..8a6634473a2 --- /dev/null +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala @@ -0,0 +1,428 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.udf + +import scala.annotation.tailrec + +import javassist.CtClass + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * CatalystExpressionBuilder + * + * This compiles a scala lambda expression into a catalyst expression. + * + * Here are the high-level steps: + * + * 1) Use SerializedLambda and javaassist to get a reflection based interface to `function` + * this is done in [[LambdaReflection]] + * + * 2) Obtain the Control Flow Graph (CFG) using the reflection interface obtained above. + * + * 3) Get catalyst `Expressions` based on the basic blocks [[BB]] obtained in the CFG + * and simplify before replacing in the Spark Logical Plan. + * + * @param function the original Scala UDF provided by the user + */ +case class CatalystExpressionBuilder(private val function: AnyRef) extends Logging { + final private val lambdaReflection: LambdaReflection = LambdaReflection(function) + + final private val cfg = CFG(lambdaReflection) + + /** + * [[compile]]: Entry point for [[CatalystExpressionBuilder]]. + * + * With this function we: + * + * 1) Create a starting [[State]], which ultimately is used to keep track of + * the locals, stack, condition, expression. + * + * 2) Pick out the head Basic Block (BB) from the Control Flow Graph (CFG) + * NOTE: that this picks the head element, and then recurses + * + * 3) Feed head BB to the start node + * + * @param children a sequence of catalyst arguments to the udf. + * @return the compiled expression, optionally + */ + def compile(children: Seq[Expression]): Option[Expression] = { + + // create starting state, this will be: + // State([children expressions], [empty stack], cond = true, expr = None) + val entryState = State.makeStartingState(lambdaReflection, children) + + // pick first of the Basic Blocks, and start recursing + val entryBlock = cfg.basicBlocks.head + + logDebug(s"[CatalystExpressionBuilder] Attempting to compile: ${function}, " + + s"with children: ${children}, " + s"entry block: ${entryBlock}, and " + + s"entry state: ${entryState}") + + // start recursing + val compiled = doCompile(List(entryBlock), Map(entryBlock -> entryState)).map { e => + if (lambdaReflection.ret == CtClass.booleanType) { + // JVM bytecode returns an integer value when the return type is + // boolean, hence the cast. + CatalystExpressionBuilder.simplify(Cast(e, BooleanType)) + } else { + e + } + } + + if (compiled == None) { + logDebug(s"[CatalystExpressionBuilder] failed to compile") + } else { + logDebug(s"[CatalystExpressionBuilder] compiled expression: ${compiled.get.toString}") + } + + compiled + } + + /** + * doCompile: using a starting basic block and state, produce a new [[State]] based on the + * basic block's [[Instruction]] table, and ultimately recurse through the successor(s) + * of the basic block we are currently visiting in the [[CFG]], using Depth-First Traversal. + * + * 1) [[compile]] calls [[doCompile]] with the head [[BB]] of [[cfg]] and its [[State]]. + * + * 2) [[doCompile]] will fold the currently visiting [[BB]]'s instruction table, + * by making a new [[State]], that consumes the prior [[State]] + * + * 3) With the new states, we call the visiting block's [[BB.propagateState]], + * and propagate its state to update successors' states. + * + * 4) If the block's last instruction is a return, this part of the graph has reached its end, + * return whatever expression was accrued in [[State.expr]] + * + * else, recurse + * + * @param worklist stack for depth-first traversal. + * @param states map between [[BB]]s and their [[State]]s. Each [[State]] + * keeps track of locals, evaluation stack, and condition for the [[BB]]. + * @param pending [[BB]]s that are ready to be added to worklist + * once all their predecessors have been visited (have a count of 0). + * @param visited the set of [[BB]] we have seen so far. It is used to make + * sure each [[BB]] is visited only once. + * @return the compiled expression, optionally + */ + @tailrec + private def doCompile(worklist: List[BB], + states: Map[BB, State], + pending: Map[BB, Int] = cfg.predecessor.mapValues(_.size), + visited: Set[BB] = Set()): Option[Expression] = { + /** + * Pick the first block, and store the rest of the list in [[rest]]. + * + * 1) Initially, [[worklist] is [[CFG.head]] :: nil + * 2) As we recurse, [[worklist]] gets new [[BB]]s when the all of its predecessors are + * visited. + * 3) The head [[BB]] ([[basicBlock]]), then goes through the compilation process: + * i) [[State]] is obtained (at the beginning, there's a seed [[State]] added in [[compile]] + * ii) after each iteration, new [[State]] is created for [[basicBlock]]. This is the + * first step where we take + * javaassist Opcode foreach [[Instruction]] in the [[BB]]'s instruction table, and turn + * it into [[State]] + * objects with: locals, stack, condition, and an evolving catalyst expression. + * ii) the state is then propagated: + * + */ + + val basicBlock :: rest = worklist + + // find the state associated with this BB + val state: State = states(basicBlock) + + logTrace(s"States for basic block ${basicBlock} => ${state}") + + /** + * Iterate through the instruction table for the BB: + * Using [[state]] as the starting value, apply the instruction [[Instruction.apply]] + * to obtain a new [[State]]. This new state is passed back to [[Instruction.apply]], + * as foldLeft makes its way through the BB's Instruction Table. + */ + val it: Map[Int, Instruction] = basicBlock.instructionTable + + val newState: State = it.foldLeft(state) { (st: State, i: (Int, Instruction)) => + val instruction: Instruction = i._2 + instruction.makeState(lambdaReflection, basicBlock, st) + } + + val sb = new StringBuilder() + sb.append(s"[CatalystExpressionBuilder.doCompile] Basic Block ${basicBlock}") + + // when you have branching expressions, we need to look at both the true and false expressions + // if (x > 0) 1 else 0 + val newStates = basicBlock.propagateState(cfg, states + (basicBlock -> newState)) + + // This is testing whether the last instruction of the basic block is a return. + // A basic block can have other branching instructions as the last instruction, + // otherwise. + if (basicBlock.lastInstruction.isReturn) { + newStates(basicBlock).expr + } else { + // account for this block in visited + val newVisited = visited + basicBlock + + /** + * 1) For the currently vising [[BB]], get the successors from the [[CFG]]. + * 2) The foldLeft this list, into a list of successors ([[readySucc]]), and + * a map of predecessor [[newPending]] counts + * + * Among the successors of the current [[BB]], find the ones that are + * ready for traversal. They are added to [[readySucc]] and removed from + * [[pending]] to create [[newPending]]. + * + * A succesor is ready for traversal, if all of its predecessors have been visited. + */ + val (readySucc: List[BB], newPending: Map[BB, Int]) = + cfg.successor(basicBlock).foldLeft((List[BB](), pending)) { + case (x@(remaining: List[BB], currentPending: Map[BB, Int]), (_, successor)) => + if (newVisited(successor)) { + // This successor has already been visited through another path. + // Do not update readySucc or newPending. + x + } else { + // [[currentPending]] is used to make sure that a [[BB]] is visited after + // all its predecessors have been visited. + val count = currentPending(successor) - 1 + if (count > 0) { + // overwrite decremented successor's pending count + (remaining, // readySucc + currentPending + (successor -> count)) // newPending + } else { + // count <= 0 + // add successor to the remaining, remove from pending. + (successor :: remaining, // readySucc + currentPending - successor) // newPending + } + } + } + + if (rest.isEmpty && readySucc.isEmpty && newPending.nonEmpty) { + // We allow a node to be visited only after all its predecessors + // are visited, but if a node is the entry to a loop, all its + // predecessors cannot be visited unless this node is visited. + // This case results in an empty worklist with non-empty new pending + // list. + throw new SparkException("Unsupported control flow: loop") + } + + doCompile( + readySucc ::: rest, + newStates, + newPending, + newVisited) + } + } +} + +/** + * CatalystExpressionBuilder helper object, contains a function that is used to + * simplify a directly translated catalyst expression (from bytecode) into something simpler + * that the remaining catalyst optimizations can handle. + */ +object CatalystExpressionBuilder extends Logging { + /** simplify: given a raw converted catalyst expression, attempt to match patterns to simplify + * before handing it over to catalyst optimizers (the LogicalPlan does this later). + * + * It is called from [[State.merge]], from itself, and from [[BB.propagateState]]. + * + * @param expr + * @return + */ + @tailrec + final def simplify(expr: Expression): Expression = { + def simplifyExpr(expr: Expression): Expression = { + val res = expr match { + case And(Literal.TrueLiteral, c) => simplifyExpr(c) + case And(c, Literal.TrueLiteral) => simplifyExpr(c) + case And(Literal.FalseLiteral, c) => Literal.FalseLiteral + case And(c1@LessThan(s1, Literal(v1, t1)), + c2@LessThan(s2, Literal(v2, t2))) if s1 == s2 && t1 == t2 => { + t1 match { + case IntegerType => + if (v1.asInstanceOf[Int] < v2.asInstanceOf[Int]) { + c1 + } else { + c2 + } + case LongType => + if (v1.asInstanceOf[Long] < v2.asInstanceOf[Long]) { + c1 + } else { + c2 + } + case _ => expr + } + } + case And(c1@LessThanOrEqual(s1, Literal(v1, t1)), + c2@LessThanOrEqual(s2, Literal(v2, t2))) if s1 == s2 && t1 == t2 => { + t1 match { + case IntegerType => + if (v1.asInstanceOf[Int] < v2.asInstanceOf[Int]) { + c1 + } else { + c2 + } + case LongType => + if (v1.asInstanceOf[Long] < v2.asInstanceOf[Long]) { + c1 + } else { + c2 + } + case _ => expr + } + } + case And(c1@LessThanOrEqual(s1, Literal(v1, t1)), + c2@LessThan(s2, Literal(v2, t2))) if s1 == s2 && t1 == t2 => { + t1 match { + case IntegerType => + if (v1.asInstanceOf[Int] < v2.asInstanceOf[Int]) { + c1 + } else { + c2 + } + case LongType => + if (v1.asInstanceOf[Long] < v2.asInstanceOf[Long]) { + c1 + } else { + c2 + } + case _ => expr + } + } + case And(c1@GreaterThan(s1, Literal(v1, t1)), + c2@GreaterThan(s2, Literal(v2, t2))) if s1 == s2 && t1 == t2 => { + t1 match { + case IntegerType => + if (v1.asInstanceOf[Int] > v2.asInstanceOf[Int]) { + c1 + } else { + c2 + } + case LongType => + if (v1.asInstanceOf[Long] > v2.asInstanceOf[Long]) { + c1 + } else { + c2 + } + case _ => expr + } + } + case And(c1@GreaterThan(s1, Literal(v1, t1)), + c2@GreaterThanOrEqual(s2, Literal(v2, t2))) if s1 == s2 && t1 == t2 => { + t1 match { + case IntegerType => + if (v1.asInstanceOf[Int] >= v2.asInstanceOf[Int]) { + c1 + } else { + c2 + } + case LongType => + if (v1.asInstanceOf[Long] >= v2.asInstanceOf[Long]) { + c1 + } else { + c2 + } + case _ => expr + } + } + case And(c1, c2) => And(simplifyExpr(c1), simplifyExpr(c2)) + case Or(Literal.TrueLiteral, c) => Literal.TrueLiteral + case Or(Literal.FalseLiteral, c) => simplifyExpr(c) + case Or(c, Literal.FalseLiteral) => simplifyExpr(c) + case Or(c1@GreaterThan(s1, Literal(v1, t1)), + c2@GreaterThanOrEqual(s2, Literal(v2, t2))) if s1 == s2 && t1 == t2 => { + t1 match { + case IntegerType => + if (v1.asInstanceOf[Int] < v2.asInstanceOf[Int]) { + c1 + } else { + c2 + } + case LongType => + if (v1.asInstanceOf[Long] < v2.asInstanceOf[Long]) { + c1 + } else { + c2 + } + case _ => expr + } + } + case Or(c1, c2) => Or(simplifyExpr(c1), simplifyExpr(c2)) + case Not(Literal.TrueLiteral) => Literal.FalseLiteral + case Not(Literal.FalseLiteral) => Literal.TrueLiteral + case Not(LessThan(c1, c2)) => GreaterThanOrEqual(c1, c2) + case Not(LessThanOrEqual(c1, c2)) => GreaterThan(c1, c2) + case Not(GreaterThan(c1, c2)) => LessThanOrEqual(c1, c2) + case Not(GreaterThanOrEqual(c1, c2)) => LessThan(c1, c2) + case EqualTo(Literal(v1, _), Literal(v2, _)) => + if (v1 == v2) Literal.TrueLiteral else Literal.FalseLiteral + case LessThan(If(c1, + Literal(1, _), + If(c2, + Literal(-1, _), + Literal(0, _))), + Literal(0, _)) => simplifyExpr(And(Not(c1), c2)) + case LessThanOrEqual(If(c1, + Literal(1, _), + If(c2, + Literal(-1, _), + Literal(0, _))), + Literal(0, _)) => simplifyExpr(Not(c1)) + case GreaterThan(If(c1, + Literal(1, _), + If(c2, + Literal(-1, _), + Literal(0, _))), + Literal(0, _)) => c1 + case GreaterThanOrEqual(If(c1, + Literal(1, _), + If(c2, + Literal(-1, _), + Literal(0, _))), + Literal(0, _)) => simplifyExpr(Or(c1, Not(c2))) + case EqualTo(If(c1, + Literal(1, _), + If(c2, + Literal(-1, _), + Literal(0, _))), + Literal(0, _)) => simplifyExpr(And(Not(c1), Not(c2))) + case If(c, t, f) if t == f => t + // JVMachine encodes boolean array components using 1 to represent true + // and 0 to represent false (see + // https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.3.4). + case Cast(Literal(0, _), BooleanType, _) => Literal.FalseLiteral + case Cast(Literal(1, _), BooleanType, _) => Literal.TrueLiteral + case Cast(If(c, t, f), BooleanType, tz) => + simplifyExpr(If(simplifyExpr(c), + simplifyExpr(Cast(t, BooleanType, tz)), + simplifyExpr(Cast(f, BooleanType, tz)))) + case _ => expr + } + logDebug(s"[CatalystExpressionBuilder] simplify: ${expr} ==> ${res}") + res + } + + val simplifiedExpr = simplifyExpr(expr) + if (simplifiedExpr == expr) simplifiedExpr else simplify(simplifiedExpr) + } +} diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/GpuScalaUDF.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/GpuScalaUDF.scala new file mode 100644 index 00000000000..4344772f337 --- /dev/null +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/GpuScalaUDF.scala @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.udf + +import com.nvidia.spark.rapids.RapidsConf + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.DataType + +case class GpuScalaUDFLogical(udf: ScalaUDF) extends Expression with Logging { + override def nullable: Boolean = udf.nullable + + override def eval(input: InternalRow): Any = { + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + null + } + + override def dataType: DataType = udf.dataType + + override def children: Seq[Expression] = udf.children + + def compile(isTestEnabled: Boolean): Expression = { + // call the compiler + try { + val expr = CatalystExpressionBuilder(udf.function).compile(udf.children) + if (expr.isDefined) { + expr.get + } else { + udf + } + } catch { + case e: SparkException => + logDebug("UDF compilation failure: " + e) + if (isTestEnabled) { + throw e + } + udf + } + } +} diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala new file mode 100644 index 00000000000..02adb5d381d --- /dev/null +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala @@ -0,0 +1,549 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.udf + +import CatalystExpressionBuilder.simplify +import java.nio.charset.Charset + +import javassist.bytecode.{CodeIterator, Opcode} +import org.apache.spark.SparkException + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + + +private object Repr { + + abstract class CompilerInternal(name: String) extends Expression { + override def dataType: DataType = { + throw new SparkException(s"Compiler internal representation of " + + s"${name} cannot be evaluated") + } + + override def doGenCode(ctx: codegen.CodegenContext, ev: codegen.ExprCode): codegen.ExprCode = { + throw new SparkException(s"Cannot generate code for compiler internal " + + s"representation of ${name}") + } + + override def eval(input: org.apache.spark.sql.catalyst.InternalRow): Any = { + throw new SparkException(s"Compiler internal representation of " + + s"${name} cannot be evaluated") + } + + override def nullable: Boolean = { + throw new SparkException(s"Compiler internal representation of " + + s"${name} cannot be evaluated") + } + + override def children: Seq[Expression] = { + throw new SparkException(s"Compiler internal representation of " + + s"${name} cannot be evaluated") + } + } + + // Internal representation of java.lang.StringBuilder. + case class StringBuilder() extends CompilerInternal("java.lang.StringBuilder") { + def invoke(methodName: String, args: List[Expression]): Expression = { + methodName match { + case "StringBuilder" => this + case "append" => string = Concat(string :: args) + this + case "toString" => string + case _ => + throw new SparkException(s"Unsupported StringBuilder op ${methodName}") + } + } + + var string: Expression = Literal.default(StringType) + } + +} + +/** + * + * @param opcode + * @param operand + */ +case class Instruction(opcode: Int, operand: Int, instructionStr: String) extends Logging { + def makeState(lambdaReflection: LambdaReflection, basicBlock: BB, state: State): State = { + val st = opcode match { + case Opcode.ALOAD_0 | Opcode.DLOAD_0 | Opcode.FLOAD_0 | + Opcode.ILOAD_0 | Opcode.LLOAD_0 => load(state, 0) + case Opcode.ALOAD_1 | Opcode.DLOAD_1 | Opcode.FLOAD_1 | + Opcode.ILOAD_1 | Opcode.LLOAD_1 => load(state, 1) + case Opcode.ALOAD_2 | Opcode.DLOAD_2 | Opcode.FLOAD_2 | + Opcode.ILOAD_2 | Opcode.LLOAD_2 => load(state, 2) + case Opcode.ALOAD_3 | Opcode.DLOAD_3 | Opcode.FLOAD_3 | + Opcode.ILOAD_3 | Opcode.LLOAD_3 => load(state, 3) + case Opcode.ALOAD | Opcode.DLOAD | Opcode.FLOAD | + Opcode.ILOAD | Opcode.LLOAD => load(state, operand) + case Opcode.ASTORE_0 | Opcode.DSTORE_0 | Opcode.FSTORE_0 | + Opcode.ISTORE_0 | Opcode.LSTORE_0 => store(state, 0) + case Opcode.ASTORE_1 | Opcode.DSTORE_1 | Opcode.FSTORE_1 | + Opcode.ISTORE_1 | Opcode.LSTORE_1 => store(state, 1) + case Opcode.ASTORE_2 | Opcode.DSTORE_2 | Opcode.FSTORE_2 | + Opcode.ISTORE_2 | Opcode.LSTORE_2 => store(state, 2) + case Opcode.ASTORE_3 | Opcode.DSTORE_3 | Opcode.FSTORE_3 | + Opcode.ISTORE_3 | Opcode.LSTORE_3 => store(state, 3) + case Opcode.ASTORE | Opcode.DSTORE | Opcode.FSTORE | + Opcode.ISTORE | Opcode.LSTORE => store(state, operand) + case Opcode.DCONST_0 | Opcode.DCONST_1 => + const(state, (opcode - Opcode.DCONST_0).asInstanceOf[Double]) + case Opcode.FCONST_0 | Opcode.FCONST_1 | Opcode.FCONST_2 => + const(state, (opcode - Opcode.FCONST_0).asInstanceOf[Float]) + case Opcode.BIPUSH | Opcode.SIPUSH => + const(state, operand) + case Opcode.ICONST_M1 | + Opcode.ICONST_0 | Opcode.ICONST_1 | Opcode.ICONST_2 | + Opcode.ICONST_3 | Opcode.ICONST_4 | Opcode.ICONST_5 => + const(state, (opcode - Opcode.ICONST_0).asInstanceOf[Int]) + case Opcode.LCONST_0 | Opcode.LCONST_1 => + const(state, (opcode - Opcode.LCONST_0).asInstanceOf[Long]) + case Opcode.DADD | Opcode.FADD | Opcode.IADD | Opcode.LADD => binary(state, Add(_, _)) + case Opcode.DSUB | Opcode.FSUB | Opcode.ISUB | Opcode.LSUB => binary(state, Subtract(_, _)) + case Opcode.DMUL | Opcode.FMUL | Opcode.IMUL | Opcode.LMUL => binary(state, Multiply(_, _)) + case Opcode.DDIV | Opcode.FDIV => binary(state, Divide(_, _)) + case Opcode.IDIV | Opcode.LDIV => binary(state, IntegralDivide(_, _)) + case Opcode.DREM | Opcode.FREM | Opcode.IREM | Opcode.LREM => binary(state, Remainder(_, _)) + case Opcode.IAND | Opcode.LAND => binary(state, BitwiseAnd(_, _)) + case Opcode.IOR | Opcode.LOR => binary(state, BitwiseOr(_, _)) + case Opcode.IXOR | Opcode.LXOR => binary(state, BitwiseXor(_, _)) + case Opcode.ISHL | Opcode.LSHL => binary(state, ShiftLeft(_, _)) + case Opcode.ISHR | Opcode.LSHR => binary(state, ShiftRight(_, _)) + case Opcode.IUSHR | Opcode.LUSHR => binary(state, ShiftRightUnsigned(_, _)) + case Opcode.DNEG | Opcode.FNEG | Opcode.INEG | Opcode.LNEG => neg(state) + case Opcode.DCMPL | Opcode.FCMPL => cmp(state, -1) + case Opcode.DCMPG | Opcode.FCMPG => cmp(state, 1) + case Opcode.LCMP => cmp(state) + case Opcode.LDC | Opcode.LDC_W | Opcode.LDC2_W => ldc(lambdaReflection, state) + case Opcode.DUP => dup(state) + case Opcode.GETSTATIC => getstatic(state) + case Opcode.NEW => newObj(lambdaReflection, state) + // Cast instructions + case Opcode.I2B => cast(state, ByteType) + case Opcode.I2C => + throw new SparkException("Opcode.I2C unsupported: no corresponding Catalyst expression") + case Opcode.F2D | Opcode.I2D | Opcode.L2D => cast(state, DoubleType) + case Opcode.D2F | Opcode.I2F | Opcode.L2F => cast(state, FloatType) + case Opcode.D2I | Opcode.F2I | Opcode.L2I => cast(state, IntegerType) + case Opcode.D2L | Opcode.F2L | Opcode.I2L => cast(state, LongType) + case Opcode.I2S => cast(state, ShortType) + // Branching instructions + // if_acmp isn't supported. + case Opcode.IF_ICMPEQ => ifCmp(state, (x, y) => simplify(EqualTo(x, y))) + case Opcode.IF_ICMPNE => ifCmp(state, (x, y) => simplify(Not(EqualTo(x, y)))) + case Opcode.IF_ICMPLT => ifCmp(state, (x, y) => simplify(LessThan(x, y))) + case Opcode.IF_ICMPGE => ifCmp(state, (x, y) => simplify(GreaterThanOrEqual(x, y))) + case Opcode.IF_ICMPGT => ifCmp(state, (x, y) => simplify(GreaterThan(x, y))) + case Opcode.IF_ICMPLE => ifCmp(state, (x, y) => simplify(LessThanOrEqual(x, y))) + case Opcode.IFLT => ifOp(state, x => simplify(LessThan(x, Literal(0)))) + case Opcode.IFLE => ifOp(state, x => simplify(LessThanOrEqual(x, Literal(0)))) + case Opcode.IFGT => ifOp(state, x => simplify(GreaterThan(x, Literal(0)))) + case Opcode.IFGE => ifOp(state, x => simplify(GreaterThanOrEqual(x, Literal(0)))) + case Opcode.IFEQ => ifOp(state, x => simplify(EqualTo(x, Literal(0)))) + case Opcode.IFNE => ifOp(state, x => simplify(Not(EqualTo(x, Literal(0))))) + case Opcode.IFNULL => ifOp(state, x => simplify(IsNull(x))) + case Opcode.IFNONNULL => ifOp(state, x => simplify(IsNotNull(x))) + case Opcode.TABLESWITCH | Opcode.LOOKUPSWITCH => switch(state) + case Opcode.GOTO => state + case Opcode.IRETURN | Opcode.LRETURN | Opcode.FRETURN | Opcode.DRETURN | + Opcode.ARETURN | Opcode.RETURN => + state.copy(expr = Some(state.stack.head)) + // Call instructions + case Opcode.INVOKESTATIC => + invoke(lambdaReflection, state, + (stack, n) => { + val (args, rest) = stack.splitAt(n) + (args.reverse, rest) + }) + case Opcode.INVOKEVIRTUAL | Opcode.INVOKESPECIAL => + invoke(lambdaReflection, state, + (stack, n) => { + val (args, rest) = stack.splitAt(n + 1) + (args.reverse, rest) + }) + case _ => throw new SparkException("Unsupported instruction: " + instructionStr) + } + logDebug(s"[Instruction] ${instructionStr} got new state: ${st} from state: ${state}") + st + } + + def isReturn: Boolean = opcode match { + case Opcode.IRETURN | Opcode.LRETURN | Opcode.FRETURN | Opcode.DRETURN | + Opcode.ARETURN | Opcode.RETURN => true + case _ => false + } + + // + // Handle instructions + // + private def load(state: State, localsIndex: Int): State = { + val State(locals, stack, cond, expr) = state + State(locals, locals(localsIndex) :: stack, cond, expr) + } + + private def store(state: State, localsIndex: Int): State = { + val State(locals, top :: rest, cond, expr) = state + State(locals.updated(localsIndex, top), rest, cond, expr) + } + + private def const(state: State, value: Any): State = { + val State(locals, stack, cond, expr) = state + State(locals, Literal(value) :: stack, cond, expr) + } + + private def binary(state: State, op: (Expression, Expression) => Expression): State = { + val State(locals, op2 :: op1 :: rest, cond, expr) = state + State(locals, op(op1, op2) :: rest, cond, expr) + } + + private def neg(state: State): State = { + val State(locals, top :: rest, cond, expr) = state + State(locals, UnaryMinus(top) :: rest, cond, expr) + } + + private def ldc(lambdaReflection: LambdaReflection, state: State): State = { + val State(locals, stack, cond, expr) = state + val constant = Literal(lambdaReflection.lookupConstant(operand)) + State(locals, constant :: stack, cond, expr) + } + + private def dup(state: State): State = { + val State(locals, top :: rest, cond, expr) = state + State(locals, top :: top :: rest, cond, expr) + } + + private def newObj(lambdaReflection: LambdaReflection, + state: State): State = { + val typeName = lambdaReflection.lookupClassName(operand) + if (typeName.equals("java.lang.StringBuilder")) { + val State(locals, stack, cond, expr) = state + State(locals, Repr.StringBuilder() :: stack, cond, expr) + } else { + throw new SparkException("Unsupported type for new:" + typeName) + } + } + + private def getstatic(state: State): State = { + val State(locals, stack, cond, expr) = state + State(locals, Literal(operand) :: stack, cond, expr) + } + + private def cmp(state: State, default: Int): State = { + val State(locals, op2 :: op1 :: rest, cond, expr) = state + val conditional = + If(Or(IsNaN(op1), IsNaN(op2)), + Literal(default), + If(GreaterThan(op1, op2), + Literal(1), + If(LessThan(op1, op2), + Literal(-1), + Literal(0)))) + State(locals, conditional :: rest, cond, expr) + } + + private def cmp(state: State): State = { + val State(locals, op2 :: op1 :: rest, cond, expr) = state + val conditional = + If(GreaterThan(op1, op2), + Literal(1), + If(LessThan(op1, op2), + Literal(-1), + Literal(0))) + State(locals, conditional :: rest, cond, expr) + } + + private def cast( + state: State, + dataType: DataType): State = { + val State(locals, top :: rest, cond, expr) = state + State(locals, Cast(top, dataType) :: rest, cond, expr) + } + + private def ifCmp(state: State, + predicate: (Expression, Expression) => Expression): State = { + val State(locals, op2 :: op1 :: rest, cond, expr) = state + State(locals, rest, cond, Some(predicate(op1, op2))) + } + + private def ifOp( + state: State, + predicate: Expression => Expression): State = { + val State(locals, top :: rest, cond, expr) = state + State(locals, rest, cond, Some(predicate(top))) + } + + private def switch(state: State): State = { + val State(locals, top :: rest, cond, expr) = state + State(locals, rest, cond, Some(top)) + } + + private def invoke(lambdaReflection: LambdaReflection, state: State, + getArgs: (List[Expression], Int) => + (List[Expression], List[Expression])): State = { + val State(locals, stack, cond, expr) = state + val method = lambdaReflection.lookupBehavior(operand) + val declaringClassName = method.getDeclaringClass.getName + val paramTypes = method.getParameterTypes + val (args, rest) = getArgs(stack, paramTypes.length) + // We don't support arbitrary calls. + // We support only some math and string methods. + if (declaringClassName.equals("scala.math.package$")) { + State(locals, + mathOp(lambdaReflection, method.getName, args) :: rest, + cond, + expr) + } else if (declaringClassName.equals("java.lang.String")) { + State(locals, stringOp(method.getName, args) :: rest, cond, expr) + } else if (declaringClassName.equals("java.lang.StringBuilder")) { + if (!args.head.isInstanceOf[Repr.StringBuilder]) { + throw new SparkException("Internal error with StringBuilder") + } + val retval = args.head.asInstanceOf[Repr.StringBuilder] + .invoke(method.getName, args.tail) + State(locals, retval :: rest, cond, expr) + } else { + // Other functions + throw new SparkException("Unsupported instruction: " + Opcode.INVOKEVIRTUAL) + } + } + + def mathOp(lambdaReflection: LambdaReflection, + methodName: String, args: List[Expression]): Expression = { + // Math unary functions + if (args.length != 2) { + throw new SparkException( + s"Unary math operation expects 1 argument and an objref, but " + + s"instead got ${args.length - 1} arguments and an objref.") + } + // Make sure that the objref is scala.math.package$. + args.head match { + case Literal(index, IntegerType) => + if (!lambdaReflection.lookupField(index.asInstanceOf[Int]) + .getType.getName.equals("scala.math.package$")) { + throw new SparkException("Unsupported math function objref: " + args.head) + } + case _ => + throw new SparkException("Unsupported math function objref: " + args.head) + } + // Translate to Catalyst + val arg = args.last + methodName match { + case "abs" => Abs(arg) + case "acos" => Acos(arg) + case "asin" => Asin(arg) + case "atan" => Atan(arg) + case "cos" => Cos(arg) + case "cosh" => Cosh(arg) + case "sin" => Sin(arg) + case "tan" => Tan(arg) + case "tanh" => Tanh(arg) + case "ceil" => Ceil(arg) + case "floor" => Floor(arg) + case "exp" => Exp(arg) + case "log" => Log(arg) + case "log10" => Log10(arg) + case "sqrt" => Sqrt(arg) + case _ => throw new SparkException("Unsupported math function: " + methodName) + } + } + + def stringOp(methodName: String, args: List[Expression]): Expression = { + def checkArgs(expectedTypes: List[DataType]): Unit = { + if (args.length != expectedTypes.length) { + throw new SparkException( + s"String.${methodName} operation expects ${expectedTypes.length} " + + s"argument(s), including an objref, but instead got ${args.length} " + + s"argument(s)") + } + args.view.zip(expectedTypes.view).foreach { case (arg, expectedType) => + if (arg.dataType != expectedType) { + throw new SparkException(s"${arg.dataType} argument found for " + + s"String.${methodName} where " + + s"${expectedType} argument is expected.") + } + } + } + + methodName match { + case "concat" => + checkArgs(List(StringType, StringType)) + Concat(args) + case "contains" => + checkArgs(List(StringType, StringType)) + Contains(args.head, args.last) + case "endsWith" => + checkArgs(List(StringType, StringType)) + EndsWith(args.head, args.last) + case "equals" => + checkArgs(List(StringType, StringType)) + Cast(EqualNullSafe(args.head, args.last), IntegerType) + case "equalsIgnoreCase" => + checkArgs(List(StringType, StringType)) + Cast(EqualNullSafe(Upper(args.head), Upper(args.last)), IntegerType) + case "isEmpty" => + checkArgs(List(StringType)) + Cast(EqualTo(Length(args.head), Literal(0)), IntegerType) + case "length" => + checkArgs(List(StringType)) + Length(args.head) + case "startsWith" => + checkArgs(List(StringType, StringType)) + StartsWith(args.head, args.last) + case "toLowerCase" => + checkArgs(List(StringType)) + Lower(args.head) + case "toUpperCase" => + checkArgs(List(StringType)) + Upper(args.head) + case "trim" => + checkArgs(List(StringType)) + StringTrim(args.head) + case "replace" => + if (args.length != 3) { + throw new SparkException( + s"String.${methodName} operation expects 3 argument(s), " + + s"including an objref, but instead got ${args.length} " + + s"argument(s)") + } + if (args(1).dataType == StringType && + args(2).dataType == StringType) { + StringReplace(args(0), args(1), args(2)) + } else if (args(1).dataType == IntegerType && + args(2).dataType == IntegerType) { + StringReplace(args(0), Chr(args(1)), Chr(args(2))) + } else { + throw new SparkException(s"Unsupported argument type for " + + s"String.${methodName}: " + + s"${args(0).dataType}, " + + s"${args(1).dataType}, and " + + s"${args(2).dataType}") + } + case "substring" => + checkArgs(StringType :: List.fill(args.length - 1)(IntegerType)) + Substring(args(0), + Add(args(1), Literal(1)), + Subtract(if (args.length == 3) args(2) else Length(args(0)), + args(1))) + case "valueOf" => + val supportedArgs = List(BooleanType, DoubleType, FloatType, + IntegerType, LongType) + if (args.length != 1) { + throw new SparkException( + s"String.${methodName} operation expects 1 " + + s"argument(s), including an objref, but instead got ${args.length} " + + s"argument(s)") + } + if (!supportedArgs.contains(args.head.dataType)) { + throw new SparkException(s"Unsupported argument type for " + + s"String.${methodName}: " + + s"${args.head.dataType}") + } + Cast(args.head, StringType) + case "indexOf" => + if (args.length == 2) { + if (args(1).dataType == StringType) { + Subtract(StringInstr(args(0), args(1)), Literal(1)) + } else { + throw new SparkException(s"Unsupported argument type for " + + s"String.${methodName}: " + + s"${args(0).dataType} and " + + s"${args(1).dataType}") + } + } else if (args.length == 3) { + if (args(1).dataType == StringType && + args(2).dataType == IntegerType) { + Subtract(StringLocate(args(1), args(0), Add(args(2), Literal(1))), + Literal(1)) + } else { + throw new SparkException(s"Unsupported argument type for " + + s"String.${methodName}: " + + s"${args(0).dataType}, " + + s"${args(1).dataType}, and " + + s"${args(2).dataType}") + } + } else { + throw new SparkException( + s"String.${methodName} operation expects 2 or 3 argument(s), " + + s"including an objref, but instead got ${args.length} " + + s"argument(s)") + } + case "replaceAll" => + checkArgs(List(StringType, StringType, StringType)) + RegExpReplace(args(0), args(1), args(2)) + case "split" => + if (args.length == 2) { + checkArgs(List(StringType, StringType)) + StringSplit(args(0), args(1), Literal(-1)) + } else if (args.length == 3) { + checkArgs(List(StringType, StringType, IntegerType)) + StringSplit(args(0), args(1), args(2)) + } else { + throw new SparkException( + s"String.${methodName} operation expects 2 or 3 argument(s), " + + s"including an objref, but instead got ${args.length} " + + s"argument(s)") + } + case "getBytes" => + if (args.length == 1) { + checkArgs(List(StringType)) + Encode(args.head, Literal(Charset.defaultCharset.toString)) + } else if (args.length == 2) { + checkArgs(List(StringType, StringType)) + Encode(args.head, args.last) + } else { + throw new SparkException( + s"String.${methodName} operation expects 1 or 2 argument(s), " + + s"including an objref, but instead got ${args.length} " + + s"argument(s)") + } + case _ => + throw new SparkException(s"Unsupported string function: " + + s"String.${methodName}") + } + } +} + +/** + * Ultimately, every opcode will have to be covered here. + */ +object Instruction { + def apply(codeIterator: CodeIterator, offset: Int, instructionStr: String): Instruction = { + val opcode: Int = codeIterator.byteAt(offset) + val operand: Int = opcode match { + case Opcode.ALOAD | Opcode.DLOAD | Opcode.FLOAD | + Opcode.ILOAD | Opcode.LLOAD | Opcode.LDC => + codeIterator.byteAt(offset + 1) + case Opcode.BIPUSH => + codeIterator.signedByteAt(offset + 1) + case Opcode.LDC_W | Opcode.LDC2_W | Opcode.NEW | + Opcode.INVOKESTATIC | Opcode.INVOKEVIRTUAL | Opcode.INVOKEINTERFACE | + Opcode.INVOKESPECIAL | Opcode.GETSTATIC => + codeIterator.u16bitAt(offset + 1) + case Opcode.GOTO | + Opcode.IFEQ | Opcode.IFNE | Opcode.IFLT | + Opcode.IFGE | Opcode.IFGT | Opcode.IFLE | + Opcode.IFNULL | Opcode.IFNONNULL | + Opcode.SIPUSH => + codeIterator.s16bitAt(offset + 1) + case _ => 0 + } + Instruction(opcode, operand, instructionStr) + } +} diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala new file mode 100644 index 00000000000..bc0939d8361 --- /dev/null +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.udf + +import java.lang.invoke.SerializedLambda + +import javassist.{ClassClassPath, ClassPool, CtBehavior, CtClass, CtField} +import javassist.bytecode.{CodeIterator, ConstPool, Descriptor} + +import org.apache.spark.SparkException + +// +// Reflection using SerializedLambda and javassist. +// +// Provides the interface the class and the method that implements the body of the lambda +// used by the rest of the compiler. +// +case class LambdaReflection(private val classPool: ClassPool, + private val serializedLambda: SerializedLambda) { + def lookupConstant(constPoolIndex: Int): Any = { + constPool.getTag(constPoolIndex) match { + case ConstPool.CONST_Integer => constPool.getIntegerInfo(constPoolIndex) + case ConstPool.CONST_Long => constPool.getLongInfo(constPoolIndex) + case ConstPool.CONST_Float => constPool.getFloatInfo(constPoolIndex) + case ConstPool.CONST_Double => constPool.getDoubleInfo(constPoolIndex) + case ConstPool.CONST_String => constPool.getStringInfo(constPoolIndex) + case _ => throw new SparkException("Unsupported constant") + } + } + + def lookupField(constPoolIndex: Int): CtField = { + if (constPool.getTag(constPoolIndex) != ConstPool.CONST_Fieldref) { + throw new SparkException("Unexpected index for field reference") + } + val fieldName = constPool.getFieldrefName(constPoolIndex) + val descriptor = constPool.getFieldrefType(constPoolIndex) + val className = constPool.getFieldrefClassName(constPoolIndex) + classPool.getCtClass(className).getField(fieldName, descriptor) + } + + def lookupBehavior(constPoolIndex: Int): CtBehavior = { + if (constPool.getTag(constPoolIndex) != ConstPool.CONST_Methodref) { + throw new SparkException("Unexpected index for method reference") + } + val methodName = constPool.getMethodrefName(constPoolIndex) + val descriptor = constPool.getMethodrefType(constPoolIndex) + val className = constPool.getMethodrefClassName(constPoolIndex) + val params = Descriptor.getParameterTypes(descriptor, classPool) + if (constPool.isConstructor(className, constPoolIndex) == 0) { + classPool.getCtClass(className).getDeclaredMethod(methodName, params) + } else { + classPool.getCtClass(className).getDeclaredConstructor(params) + } + } + + def lookupClassName(constPoolIndex: Int): String = { + if (constPool.getTag(constPoolIndex) != ConstPool.CONST_Class) { + throw new SparkException("Unexpected index for class") + } + constPool.getClassInfo(constPoolIndex) + } + + // Get the CtClass object for the class that capture the lambda. + private val ctClass = { + val name = serializedLambda.getCapturingClass.replace('/', '.') + val loader = Thread.currentThread().getContextClassLoader + // scalastyle:off classforname + val classForName = Class.forName(name, true, loader) + // scalastyle:on classforname + classPool.insertClassPath(new ClassClassPath(classForName)) + classPool.getCtClass(name) + } + + // Get the CtMethod object for the method that implements the lambda body. + private val ctMethod = { + val lambdaImplName = serializedLambda.getImplMethodName + ctClass.getDeclaredMethod(lambdaImplName.stripSuffix("$adapted")) + } + + private val methodInfo = ctMethod.getMethodInfo + + val constPool = methodInfo.getConstPool + + private val codeAttribute = methodInfo.getCodeAttribute + + lazy val codeIterator: CodeIterator = codeAttribute.iterator + + lazy val parameters: Array[CtClass] = ctMethod.getParameterTypes + + lazy val ret: CtClass = ctMethod.getReturnType + + lazy val maxLocals: Int = codeAttribute.getMaxLocals +} + +object LambdaReflection { + def apply(function: AnyRef): LambdaReflection = { + // writeReplace is supposed to return an object of SerializedLambda from + // the function class (See + // https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/SerializedLambda.html). + // With the object of SerializedLambda, we can get our hands on the class + // and the method that implement the lambda body. + val functionClass = function.getClass + val writeReplace = functionClass.getDeclaredMethod("writeReplace") + writeReplace.setAccessible(true) + val serializedLambda = writeReplace.invoke(function) + .asInstanceOf[SerializedLambda] + + val classPool = new ClassPool(true) + LambdaReflection(classPool, serializedLambda) + } +} + diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Plugin.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Plugin.scala index 451f1a4f54d..679d126b95c 100644 --- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Plugin.scala +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Plugin.scala @@ -16,29 +16,64 @@ package com.nvidia.spark.udf +import ai.rapids.cudf.{NvtxColor, NvtxRange} +import com.nvidia.spark.rapids.RapidsConf + import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSessionExtensions import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression, ScalaUDF} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import com.nvidia.spark.rapids.RapidsConf - class Plugin extends Function1[SparkSessionExtensions, Unit] with Logging { override def apply(extensions: SparkSessionExtensions): Unit = { logWarning("Installing rapids UDF compiler extensions to Spark. The compiler is disabled" + - s" by default. To enable it, set `${RapidsConf.UDF_COMPILER_ENABLED}` to true") + s" by default. To enable it, set `${RapidsConf.UDF_COMPILER_ENABLED}` to true") extensions.injectResolutionRule(_ => LogicalPlanRules()) } } case class LogicalPlanRules() extends Rule[LogicalPlan] with Logging { def replacePartialFunc(plan: LogicalPlan): PartialFunction[Expression, Expression] = { - case d: Expression => attemptToReplaceExpression(plan, d) + case d: Expression => { + val nvtx = new NvtxRange("replace UDF", NvtxColor.BLUE) + try { + attemptToReplaceExpression(plan, d) + } finally { + nvtx.close() + } + } } def attemptToReplaceExpression(plan: LogicalPlan, exp: Expression): Expression = { - exp + val conf = new RapidsConf(plan.conf) + // iterating over NamedExpression + exp match { + case f: ScalaUDF => // found a ScalaUDF + GpuScalaUDFLogical(f).compile(conf.isTestEnabled) + case _ => + if (exp == null) { + exp + } else { + try { + if (exp.children != null && !exp.children.exists(x => x == null)) { + exp.withNewChildren(exp.children.map(c => { + if (c != null && c.isInstanceOf[Expression]) { + attemptToReplaceExpression(plan, c) + } else { + c + } + })) + } else { + exp + } + } catch { + case npe: NullPointerException => { + exp + } + } + } + } } override def apply(plan: LogicalPlan): LogicalPlan = { @@ -47,7 +82,7 @@ case class LogicalPlanRules() extends Rule[LogicalPlan] with Logging { plan match { case project: Project => Project(project.projectList.map(e => attemptToReplaceExpression(plan, e)) - .asInstanceOf[Seq[NamedExpression]], project.child) + .asInstanceOf[Seq[NamedExpression]], project.child) case x => { x.transformExpressions(replacePartialFunc(plan)) } diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala new file mode 100644 index 00000000000..df49b80a664 --- /dev/null +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.udf + +import CatalystExpressionBuilder.simplify +import javassist.CtClass + +import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal, Or} + +/** + * State is used as the main representation of block state, as we walk the bytecode. + * + * Given a set of instructions, we will use State variables to track what happens to the stack. + * + * The final State generated is later used to simplify expressions. + * + * Example 1: + * { + * return 0 + * } + * + * This is in java byte code: + * iconst 0 + * ireturn + * + * iconst 0 => pushes 0 to stack + * ireturn => pops and returns + * + * Example 2: + * { + * return 2 + 2 + 1 + * } + * + * 1) State(locals, empty stack, no condition, expr?) // expr ==it is no expression + * + * iconst 2 + * NOTE: 2 is literal here so is 1 + * 2) State(locals, 2::Nil, no condition, expr is still empty) + * + * iconst 2 + * 3) State(locals, 2::2::Nil, no condiiton...) + * + * iadd (pop 2 and 2 + push 4 into stack) + * 4) State(locals, Add(2, 2)::Nil, no condition, expr is still empoty) + * + * iconst 1 + * 5) State(locals, 1::Add(2,2)::Nil, ..) + * + * iadd (pop 1 and 4 + push 5 into sack) + * 6) Add(1, Add(2,2)) :: Nil + * + * ireturn (pop 5 from stack) + * 7) return Add... + * + * State == Add + * stack == lhs/rhs + * + * @param locals + * @param stack + * @param cond + * @param expr + */ +case class State(locals: Array[Expression], + stack: List[Expression] = List(), + cond: Expression = Literal.TrueLiteral, + expr: Option[Expression] = None) { + + def merge(that: Option[State]): State = { + that.fold(this) { s => + val combine: ((Expression, Expression)) => + Expression = { + case (l1, l2) => simplify(If(cond, l1, l2)) + } + // At the end of the compliation, the expression at the top of stack is + // returned, which must have all the conditionals embedded, if the + // bytecode had any conditional. For this reason, we apply combine to + // each element in the stack and locals. + s.copy(locals = locals.zip(s.locals).map(combine), + stack = stack.zip(s.stack).map(combine), + // The combined state is for the cases s.cond is met or cond + // is met, hence or. + cond = simplify(Or(s.cond, cond))) + } + } + + override def toString: String = { + s"State(locals=[${printExpressions(locals)}], stack=[${printExpressions(stack)}], " + + s"cond=[${printExpressions(Seq(cond))}], expr=[${expr.map(e => e.toString())}])" + } + + private def printExpressions(expressions: Iterable[Expression]): String = { + if (expressions == null) { + "NULL" + } else { + expressions.map(e => if (e == null) { + "NULL" + } else { + e.toString() + }).mkString(", ") + } + } +} + +object State { + def makeStartingState(lambdaReflection: LambdaReflection, + children: Seq[Expression]): State = { + val max = lambdaReflection.maxLocals + val params: Seq[(CtClass, Expression)] = lambdaReflection.parameters.view.zip(children) + val (locals, _) = params.foldLeft((new Array[Expression](max), 0)) { (l, p) => + val (locals: Array[Expression], index: Int) = l + val (param: CtClass, argExp: Expression) = p + + val newIndex = if (param == CtClass.doubleType || param == CtClass.longType) { + // Long and Double occupies two slots in the local variable array. + // See https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.6.1 + index + 2 + } else { + index + 1 + } + + (locals.updated(index, argExp), newIndex) + } + State(locals) + } +}