From d2b6bfc6c0effc98bebdbe58c7e2d8f5a25887c6 Mon Sep 17 00:00:00 2001 From: Liangcai Li Date: Wed, 24 Feb 2021 23:56:13 +0800 Subject: [PATCH] Build python output schema from udf expressions (#1794) Signed-off-by: Firestarman --- integration_tests/src/main/python/udf_test.py | 4 +--- .../execution/python/GpuArrowEvalPythonExec.scala | 7 ++++++- .../execution/python/GpuWindowInPandasExecBase.scala | 10 ++++++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/integration_tests/src/main/python/udf_test.py b/integration_tests/src/main/python/udf_test.py index 74cae248c3b..085995874fb 100644 --- a/integration_tests/src/main/python/udf_test.py +++ b/integration_tests/src/main/python/udf_test.py @@ -14,7 +14,7 @@ import pytest -from conftest import is_at_least_precommit_run, is_databricks_runtime +from conftest import is_at_least_precommit_run from pyspark.sql.pandas.utils import require_minimum_pyarrow_version, require_minimum_pandas_version try: @@ -170,8 +170,6 @@ def pandas_sum(to_process: pd.Series) -> int: conf=arrow_udf_conf) -@pytest.mark.xfail(condition=is_databricks_runtime(), - reason='https://github.com/NVIDIA/spark-rapids/issues/1644') @ignore_order @pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen], ids=idfn) @pytest.mark.parametrize('window', udf_windows, ids=window_ids) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala index c1175617ac0..5ebfcafc965 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala @@ -555,7 +555,12 @@ case class GpuArrowEvalPythonExec( // cache in a local to avoid serializing the plan val inputSchema = child.output.toStructType - val pythonOutputSchema = 
StructType.fromAttributes(resultAttrs) + // Build the Python output schema from UDF expressions instead of the 'resultAttrs', because + // the 'resultAttrs' is NOT always equal to the Python output schema. For example, + // on Databricks, when projecting only one column from a Python UDF output that contains + // multiple result columns, there will be only one attribute in the 'resultAttrs' for the + // projected output, but the output schema for this Python UDF contains multiple columns. + val pythonOutputSchema = StructType.fromAttributes(udfs.map(_.resultAttribute)) val childOutput = child.output val targetBatchSize = batchSize diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala index b63c3009c2a..9f6ed079652 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala @@ -490,8 +490,14 @@ trait GpuWindowInPandasExecBase extends UnaryExecNode with GpuExec { lazy val isPythonOnGpuEnabled = GpuPythonHelper.isPythonOnGpuEnabled(conf, pythonModuleKey) // cache in a local to avoid serializing the plan - val retAttributes = windowExpression.map(_.asInstanceOf[NamedExpression].toAttribute) - val pythonOutputSchema = StructType.fromAttributes(retAttributes) + + // Build the Python output schema from UDF expressions instead of the 'windowExpression', + // because the 'windowExpression' does NOT always represent the Python output schema. + // For example, on Databricks, when projecting only one column from a Python UDF output + // that contains multiple result columns, there will be only one item in the + // 'windowExpression' for the projected output, but the output schema for this Python + // UDF contains multiple columns. 
+ val pythonOutputSchema = StructType.fromAttributes(udfExpressions.map(_.resultAttribute)) val childOutput = child.output // 8) Start processing.