diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala index 3f3f2803f5c..bc2f30dff2f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,8 +31,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.rapids.execution.python.shims.GpuArrowPythonRunner -import org.apache.spark.sql.rapids.shims.{ArrowUtilsShim, DataTypeUtilsShim} +import org.apache.spark.sql.rapids.execution.python.shims.GpuGroupedPythonRunnerFactory +import org.apache.spark.sql.rapids.shims.DataTypeUtilsShim import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -109,8 +109,6 @@ case class GpuAggregateInPandasExec( val (mNumInputRows, mNumInputBatches, mNumOutputRows, mNumOutputBatches) = commonGpuMetrics() lazy val isPythonOnGpuEnabled = GpuPythonHelper.isPythonOnGpuEnabled(conf) - val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pythonRunnerConf = ArrowUtilsShim.getPythonRunnerConfMap(conf) val childOutput = child.output val resultExprs = resultExpressions @@ -204,27 +202,22 @@ case class GpuAggregateInPandasExec( } } + val runnerFactory = GpuGroupedPythonRunnerFactory(conf, pyFuncs, argOffsets, + aggInputSchema, DataTypeUtilsShim.fromAttributes(pyOutAttributes), + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) + // Third, sends to Python to execute the aggregate and returns the result. if (pyInputIter.hasNext) { // Launch Python workers only when the data is not empty. - val pyRunner = new GpuArrowPythonRunner( - pyFuncs, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, - argOffsets, - aggInputSchema, - sessionLocalTimeZone, - pythonRunnerConf, - // The whole group data should be written in a single call, so here is unlimited - Int.MaxValue, - DataTypeUtilsShim.fromAttributes(pyOutAttributes)) - + val pyRunner = runnerFactory.getRunner() val pyOutputIterator = pyRunner.compute(pyInputIter, context.partitionId(), context) val combinedAttrs = gpuGroupingExpressions.map(_.toAttribute) ++ pyOutAttributes val resultRefs = GpuBindReferences.bindGpuReferences(resultExprs, combinedAttrs) // Gets the combined batch for each group and projects for the output. - new CombiningIterator(batchProducer.getBatchQueue, pyOutputIterator, pyRunner, - mNumOutputRows, mNumOutputBatches).map { combinedBatch => + new CombiningIterator(batchProducer.getBatchQueue, pyOutputIterator, + pyRunner.asInstanceOf[GpuArrowOutput], mNumOutputRows, + mNumOutputBatches).map { combinedBatch => withResource(combinedBatch) { batch => GpuProjectExec.project(batch, resultRefs) } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala index 6c2f716583f..4a24a449b24 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import com.nvidia.spark.rapids.python.PythonWorkerSemaphore import com.nvidia.spark.rapids.shims.ShimUnaryExecNode import org.apache.spark.TaskContext -import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} @@ -123,7 +123,8 @@ case class GpuFlatMapGroupsInPandasExec( resolveArgOffsets(child, groupingAttributes) val runnerFactory = GpuGroupedPythonRunnerFactory(conf, chainedFunc, Array(argOffsets), - DataTypeUtilsShim.fromAttributes(dedupAttrs), pythonOutputSchema) + DataTypeUtilsShim.fromAttributes(dedupAttrs), pythonOutputSchema, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) // Start processing. Map grouped batches to ArrowPythonRunner results. child.executeColumnar().mapPartitionsInternal { inputIter => diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala index 4186effcf84..a35cba87a16 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.execution.python.shims -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.sql.rapids.shims.ArrowUtilsShim import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -49,14 +49,15 @@ case class GpuGroupedPythonRunnerFactory( chainedFunc: Seq[ChainedPythonFunctions], argOffsets: Array[Array[Int]], dedupAttrs: StructType, - pythonOutputSchema: StructType) { + pythonOutputSchema: StructType, + evalType: Int) { val sessionLocalTimeZone = conf.sessionLocalTimeZone val pythonRunnerConf = ArrowUtilsShim.getPythonRunnerConfMap(conf) def getRunner(): GpuBasePythonRunner[ColumnarBatch] = { new GpuArrowPythonRunner( chainedFunc, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + evalType, argOffsets, dedupAttrs, sessionLocalTimeZone, @@ -65,4 +66,4 @@ case class GpuGroupedPythonRunnerFactory( Int.MaxValue, pythonOutputSchema) } -} \ No newline at end of file +} diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala index c97bf1abd3e..451de0a2527 100644 --- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala @@ -20,7 +20,7 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.execution.python.shims -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.sql.rapids.shims.ArrowUtilsShim import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -30,7 +30,8 @@ case class GpuGroupedPythonRunnerFactory( chainedFunc: Seq[ChainedPythonFunctions], argOffsets: Array[Array[Int]], dedupAttrs: StructType, - pythonOutputSchema: StructType) { + pythonOutputSchema: StructType, + evalType: Int) { // Configs from DB runtime val maxBytes = conf.pandasZeroConfConversionGroupbyApplyMaxBytesPerSlice val zeroConfEnabled = conf.pandasZeroConfConversionGroupbyApplyEnabled @@ -41,7 +42,7 @@ case class GpuGroupedPythonRunnerFactory( if (zeroConfEnabled && maxBytes > 0L) { new GpuGroupUDFArrowPythonRunner( chainedFunc, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + evalType, argOffsets, dedupAttrs, sessionLocalTimeZone, @@ -52,7 +53,7 @@ case class GpuGroupedPythonRunnerFactory( } else { new GpuArrowPythonRunner( chainedFunc, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + evalType, argOffsets, dedupAttrs, sessionLocalTimeZone, diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala index f2297248711..b1dabbf5b5e 100644 --- a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala +++ b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala @@ -24,24 +24,25 @@ import org.apache.spark.sql.rapids.shims.ArrowUtilsShim import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch -//TODO is this needed? we already have a similar version in spark330db case class GpuGroupedPythonRunnerFactory( conf: org.apache.spark.sql.internal.SQLConf, chainedFunc: Seq[ChainedPythonFunctions], argOffsets: Array[Array[Int]], dedupAttrs: StructType, - pythonOutputSchema: StructType) { + pythonOutputSchema: StructType, + evalType: Int) { // Configs from DB runtime val maxBytes = conf.pandasZeroConfConversionGroupbyApplyMaxBytesPerSlice val zeroConfEnabled = conf.pandasZeroConfConversionGroupbyApplyEnabled + val isArrowBatchSlicingEnabled = conf.pythonArrowBatchSlicingEnabled val sessionLocalTimeZone = conf.sessionLocalTimeZone val pythonRunnerConf = ArrowUtilsShim.getPythonRunnerConfMap(conf) def getRunner(): GpuBasePythonRunner[ColumnarBatch] = { - if (zeroConfEnabled && maxBytes > 0L) { + if (isArrowBatchSlicingEnabled || (zeroConfEnabled && maxBytes > 0L)) { new GpuGroupUDFArrowPythonRunner( chainedFunc, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + evalType, argOffsets, dedupAttrs, sessionLocalTimeZone, @@ -52,7 +53,7 @@ case class GpuGroupedPythonRunnerFactory( } else { new GpuArrowPythonRunner( chainedFunc, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + evalType, argOffsets, dedupAttrs, sessionLocalTimeZone,