Fix a test error for DB13.3 [databricks] (#10816)
Fix #10797

This PR makes the Arrow Python runner pick honor a new Databricks config for Arrow batch slicing, and applies that pick rule to GpuAggregateInPandasExec in addition to GpuFlatMapGroupsInPandasExec.
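
The pick rule in question is small; below is a self-contained sketch of the updated selection logic using stand-in names (RunnerPickSketch and its Runner objects are illustrative, not the real spark-rapids classes; the actual factory code is in the shim diffs further down):

object RunnerPickSketch {
  sealed trait Runner
  // Stand-in for GpuGroupUDFArrowPythonRunner, the batch-slicing-capable runner.
  case object GroupUDFArrowRunner extends Runner
  // Stand-in for GpuArrowPythonRunner, which writes the whole group data in one call.
  case object PlainArrowRunner extends Runner

  // Mirrors the updated condition in the DB13.3 shim: the new Arrow batch
  // slicing config forces the group UDF runner even when the zero-conf
  // slicing path is disabled.
  def pickRunner(isArrowBatchSlicingEnabled: Boolean,
      zeroConfEnabled: Boolean,
      maxBytesPerSlice: Long): Runner = {
    if (isArrowBatchSlicingEnabled || (zeroConfEnabled && maxBytesPerSlice > 0L)) {
      GroupUDFArrowRunner
    } else {
      PlainArrowRunner
    }
  }

  def main(args: Array[String]): Unit = {
    // The DB13.3 case this PR fixes: slicing on, zero-conf path off.
    assert(pickRunner(true, false, 0L) == GroupUDFArrowRunner)
    assert(pickRunner(false, true, 1L << 20) == GroupUDFArrowRunner)
    assert(pickRunner(false, false, 0L) == PlainArrowRunner)
    println("pick rule behaves as expected")
  }
}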

---------

Signed-off-by: Firestarman <firestarmanllc@gmail.com>
firestarman authored May 15, 2024
1 parent 4b3366f commit 95067fe
Showing 5 changed files with 32 additions and 35 deletions.
File 1 of 5
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -31,8 +31,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.rapids.execution.python.shims.GpuArrowPythonRunner
-import org.apache.spark.sql.rapids.shims.{ArrowUtilsShim, DataTypeUtilsShim}
+import org.apache.spark.sql.rapids.execution.python.shims.GpuGroupedPythonRunnerFactory
+import org.apache.spark.sql.rapids.shims.DataTypeUtilsShim
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -109,8 +109,6 @@ case class GpuAggregateInPandasExec(
   val (mNumInputRows, mNumInputBatches, mNumOutputRows, mNumOutputBatches) = commonGpuMetrics()

   lazy val isPythonOnGpuEnabled = GpuPythonHelper.isPythonOnGpuEnabled(conf)
-  val sessionLocalTimeZone = conf.sessionLocalTimeZone
-  val pythonRunnerConf = ArrowUtilsShim.getPythonRunnerConfMap(conf)
   val childOutput = child.output
   val resultExprs = resultExpressions

@@ -204,27 +202,22 @@ case class GpuAggregateInPandasExec(
         }
       }

+      val runnerFactory = GpuGroupedPythonRunnerFactory(conf, pyFuncs, argOffsets,
+        aggInputSchema, DataTypeUtilsShim.fromAttributes(pyOutAttributes),
+        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+
       // Third, sends to Python to execute the aggregate and returns the result.
       if (pyInputIter.hasNext) {
         // Launch Python workers only when the data is not empty.
-        val pyRunner = new GpuArrowPythonRunner(
-          pyFuncs,
-          PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
-          argOffsets,
-          aggInputSchema,
-          sessionLocalTimeZone,
-          pythonRunnerConf,
-          // The whole group data should be written in a single call, so here is unlimited
-          Int.MaxValue,
-          DataTypeUtilsShim.fromAttributes(pyOutAttributes))
-
+        val pyRunner = runnerFactory.getRunner()
         val pyOutputIterator = pyRunner.compute(pyInputIter, context.partitionId(), context)

         val combinedAttrs = gpuGroupingExpressions.map(_.toAttribute) ++ pyOutAttributes
         val resultRefs = GpuBindReferences.bindGpuReferences(resultExprs, combinedAttrs)
         // Gets the combined batch for each group and projects for the output.
-        new CombiningIterator(batchProducer.getBatchQueue, pyOutputIterator, pyRunner,
-          mNumOutputRows, mNumOutputBatches).map { combinedBatch =>
+        new CombiningIterator(batchProducer.getBatchQueue, pyOutputIterator,
+          pyRunner.asInstanceOf[GpuArrowOutput], mNumOutputRows,
+          mNumOutputBatches).map { combinedBatch =>
           withResource(combinedBatch) { batch =>
             GpuProjectExec.project(batch, resultRefs)
           }
File 2 of 5
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@ import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
 import com.nvidia.spark.rapids.shims.ShimUnaryExecNode

 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.ChainedPythonFunctions
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
@@ -123,7 +123,8 @@ case class GpuFlatMapGroupsInPandasExec(
       resolveArgOffsets(child, groupingAttributes)

     val runnerFactory = GpuGroupedPythonRunnerFactory(conf, chainedFunc, Array(argOffsets),
-      DataTypeUtilsShim.fromAttributes(dedupAttrs), pythonOutputSchema)
+      DataTypeUtilsShim.fromAttributes(dedupAttrs), pythonOutputSchema,
+      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

     // Start processing. Map grouped batches to ArrowPythonRunner results.
     child.executeColumnar().mapPartitionsInternal { inputIter =>
File 3 of 5
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -39,7 +39,7 @@
 spark-rapids-shim-json-lines ***/
 package org.apache.spark.sql.rapids.execution.python.shims

-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.sql.rapids.shims.ArrowUtilsShim
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -49,14 +49,15 @@ case class GpuGroupedPythonRunnerFactory(
     chainedFunc: Seq[ChainedPythonFunctions],
     argOffsets: Array[Array[Int]],
     dedupAttrs: StructType,
-    pythonOutputSchema: StructType) {
+    pythonOutputSchema: StructType,
+    evalType: Int) {
   val sessionLocalTimeZone = conf.sessionLocalTimeZone
   val pythonRunnerConf = ArrowUtilsShim.getPythonRunnerConfMap(conf)

   def getRunner(): GpuBasePythonRunner[ColumnarBatch] = {
     new GpuArrowPythonRunner(
       chainedFunc,
-      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+      evalType,
       argOffsets,
       dedupAttrs,
       sessionLocalTimeZone,
@@ -65,4 +66,4 @@ case class GpuGroupedPythonRunnerFactory(
       Int.MaxValue,
       pythonOutputSchema)
   }
-}
+}
File 4 of 5
@@ -20,7 +20,7 @@
 spark-rapids-shim-json-lines ***/
 package org.apache.spark.sql.rapids.execution.python.shims

-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.sql.rapids.shims.ArrowUtilsShim
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -30,7 +30,8 @@ case class GpuGroupedPythonRunnerFactory(
     chainedFunc: Seq[ChainedPythonFunctions],
     argOffsets: Array[Array[Int]],
     dedupAttrs: StructType,
-    pythonOutputSchema: StructType) {
+    pythonOutputSchema: StructType,
+    evalType: Int) {
   // Configs from DB runtime
   val maxBytes = conf.pandasZeroConfConversionGroupbyApplyMaxBytesPerSlice
   val zeroConfEnabled = conf.pandasZeroConfConversionGroupbyApplyEnabled
@@ -41,7 +42,7 @@ case class GpuGroupedPythonRunnerFactory(
     if (zeroConfEnabled && maxBytes > 0L) {
       new GpuGroupUDFArrowPythonRunner(
         chainedFunc,
-        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+        evalType,
         argOffsets,
         dedupAttrs,
         sessionLocalTimeZone,
@@ -52,7 +53,7 @@ case class GpuGroupedPythonRunnerFactory(
     } else {
       new GpuArrowPythonRunner(
         chainedFunc,
-        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+        evalType,
         argOffsets,
         dedupAttrs,
         sessionLocalTimeZone,
File 5 of 5
@@ -24,24 +24,25 @@ import org.apache.spark.sql.rapids.shims.ArrowUtilsShim
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch

 //TODO is this needed? we already have a similar version in spark330db
 case class GpuGroupedPythonRunnerFactory(
     conf: org.apache.spark.sql.internal.SQLConf,
     chainedFunc: Seq[ChainedPythonFunctions],
     argOffsets: Array[Array[Int]],
     dedupAttrs: StructType,
-    pythonOutputSchema: StructType) {
+    pythonOutputSchema: StructType,
+    evalType: Int) {
   // Configs from DB runtime
   val maxBytes = conf.pandasZeroConfConversionGroupbyApplyMaxBytesPerSlice
   val zeroConfEnabled = conf.pandasZeroConfConversionGroupbyApplyEnabled
+  val isArrowBatchSlicingEnabled = conf.pythonArrowBatchSlicingEnabled
   val sessionLocalTimeZone = conf.sessionLocalTimeZone
   val pythonRunnerConf = ArrowUtilsShim.getPythonRunnerConfMap(conf)

   def getRunner(): GpuBasePythonRunner[ColumnarBatch] = {
-    if (zeroConfEnabled && maxBytes > 0L) {
+    if (isArrowBatchSlicingEnabled || (zeroConfEnabled && maxBytes > 0L)) {
       new GpuGroupUDFArrowPythonRunner(
         chainedFunc,
-        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+        evalType,
         argOffsets,
         dedupAttrs,
         sessionLocalTimeZone,
@@ -52,7 +53,7 @@ case class GpuGroupedPythonRunnerFactory(
     } else {
       new GpuArrowPythonRunner(
         chainedFunc,
-        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+        evalType,
         argOffsets,
         dedupAttrs,
         sessionLocalTimeZone,
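
Taken together, the refactor threads the Python eval type through the factory constructor instead of hard-coding it in getRunner(). Below is a minimal, self-contained model of the reshaped API under assumed names (EvalTypeSketch and GroupedPythonRunnerFactory are illustrative re-declarations, not the real classes; the numeric constants mirror Spark's PythonEvalType):

object EvalTypeSketch {
  // Stand-ins mirroring org.apache.spark.api.python.PythonEvalType.
  val SQL_GROUPED_MAP_PANDAS_UDF = 201
  val SQL_GROUPED_AGG_PANDAS_UDF = 202

  // Before: the factory hard-coded SQL_GROUPED_MAP_PANDAS_UDF inside getRunner().
  // After: callers pass the eval type, so one factory serves both execs.
  case class GroupedPythonRunnerFactory(evalType: Int) {
    def getRunner(): String = s"runner(evalType=$evalType)"
  }

  def main(args: Array[String]): Unit = {
    // GpuFlatMapGroupsInPandasExec keeps its original eval type ...
    val flatMapRunner = GroupedPythonRunnerFactory(SQL_GROUPED_MAP_PANDAS_UDF).getRunner()
    // ... while GpuAggregateInPandasExec now reuses the same factory.
    val aggRunner = GroupedPythonRunnerFactory(SQL_GROUPED_AGG_PANDAS_UDF).getRunner()
    println(s"$flatMapRunner / $aggRunner")
  }
}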
