diff --git a/integration_tests/src/main/python/udf_test.py b/integration_tests/src/main/python/udf_test.py
index b7b6573ab14..19a3c9f4c9a 100644
--- a/integration_tests/src/main/python/udf_test.py
+++ b/integration_tests/src/main/python/udf_test.py
@@ -15,6 +15,7 @@
 import pytest
 
 from conftest import is_at_least_precommit_run
+from spark_session import is_databricks91_or_later
 from pyspark.sql.pandas.utils import require_minimum_pyarrow_version, require_minimum_pandas_version
 
 try:
@@ -212,6 +213,7 @@ def pandas_sum(to_process: pd.Series) -> list:
 
 
 # ======= Test flat map group in Pandas =======
+@pytest.mark.skipif(is_databricks91_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/4599")
 @ignore_order
 @pytest.mark.parametrize('data_gen', [LongGen()], ids=idfn)
 def test_group_apply_udf(data_gen):
@@ -226,6 +228,7 @@ def pandas_add(data):
         conf=arrow_udf_conf)
 
 
+@pytest.mark.skipif(is_databricks91_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/4599")
 @ignore_order
 @pytest.mark.parametrize('data_gen', arrow_common_gen, ids=idfn)
 def test_group_apply_udf_more_types(data_gen):
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 345a0e03edd..2b6cc3250c1 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -497,21 +497,6 @@ object GpuOverrides extends Logging {
     !regexList.exists(pattern => s.contains(pattern))
   }
 
-  private def convertExprToGpuIfPossible(expr: Expression, conf: RapidsConf): Expression = {
-    if (expr.find(_.isInstanceOf[GpuExpression]).isDefined) {
-      // already been converted
-      expr
-    } else {
-      val wrapped = wrapExpr(expr, conf, None)
-      wrapped.tagForGpu()
-      if (wrapped.canExprTreeBeReplaced) {
-        wrapped.convertToGpu()
-      } else {
-        expr
-      }
-    }
-  }
-
   private def convertPartToGpuIfPossible(part: Partitioning, conf: RapidsConf): Partitioning = {
     part match {
       case _: GpuPartitioning => part
@@ -910,7 +895,6 @@ object GpuOverrides extends Logging {
   private[this] def extractOrigParam(expr: BaseExprMeta[_]): BaseExprMeta[_] =
     expr.wrapped match {
       case lit: Literal if lit.dataType.isInstanceOf[DecimalType] =>
-        val dt = lit.dataType.asInstanceOf[DecimalType]
         // Lets figure out if we can make the Literal value smaller
         val (newType, value) = lit.value match {
           case null =>
@@ -930,8 +914,10 @@ object GpuOverrides extends Logging {
             throw new IllegalArgumentException(s"Unexpected decimal literal value $other")
         }
         expr.asInstanceOf[LiteralExprMeta].withNewLiteral(Literal(value, newType))
-      // We avoid unapply for Cast because it changes between versions of Spark
-      case PromotePrecision(c: CastBase) if c.dataType.isInstanceOf[DecimalType] =>
+      // Avoid unapply for PromotePrecision and Cast because it changes between Spark versions
+      case p: PromotePrecision if p.child.isInstanceOf[CastBase] &&
+          p.child.dataType.isInstanceOf[DecimalType] =>
+        val c = p.child.asInstanceOf[CastBase]
         val to = c.dataType.asInstanceOf[DecimalType]
         val fromType = DecimalUtil.optionallyAsDecimalType(c.child.dataType)
         fromType match {