Skip to content

Commit

Permalink
Merge pull request #4604 from NVIDIA/branch-22.02
Browse files Browse the repository at this point in the history
[auto-merge] branch-22.02 to branch-22.04 [skip ci] [bot]
  • Loading branch information
nvauto authored Jan 21, 2022
2 parents e9aa25c + 2090ecc commit d6e5eb7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 18 deletions.
3 changes: 3 additions & 0 deletions integration_tests/src/main/python/udf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pytest

from conftest import is_at_least_precommit_run
from spark_session import is_databricks91_or_later

from pyspark.sql.pandas.utils import require_minimum_pyarrow_version, require_minimum_pandas_version
try:
Expand Down Expand Up @@ -212,6 +213,7 @@ def pandas_sum(to_process: pd.Series) -> list:


# ======= Test flat map group in Pandas =======
@pytest.mark.skipif(is_databricks91_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/4599")
@ignore_order
@pytest.mark.parametrize('data_gen', [LongGen()], ids=idfn)
def test_group_apply_udf(data_gen):
Expand All @@ -226,6 +228,7 @@ def pandas_add(data):
conf=arrow_udf_conf)


@pytest.mark.skipif(is_databricks91_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/4599")
@ignore_order
@pytest.mark.parametrize('data_gen', arrow_common_gen, ids=idfn)
def test_group_apply_udf_more_types(data_gen):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -497,21 +497,6 @@ object GpuOverrides extends Logging {
!regexList.exists(pattern => s.contains(pattern))
}

/**
 * Attempt to replace an expression tree with its GPU equivalent.
 *
 * @param expr the expression tree to convert
 * @param conf the active RapidsConf used when tagging the wrapped expression
 * @return the GPU version of the tree when the whole tree can be replaced,
 *         otherwise the original expression unchanged
 */
private def convertExprToGpuIfPossible(expr: Expression, conf: RapidsConf): Expression = {
  // If any node in the tree is already a GpuExpression the conversion has
  // happened before, so hand the tree back untouched.
  val alreadyConverted = expr.find(_.isInstanceOf[GpuExpression]).isDefined
  if (alreadyConverted) {
    expr
  } else {
    val meta = wrapExpr(expr, conf, None)
    meta.tagForGpu()
    // Only swap in the GPU version when the entire tree is replaceable;
    // a partial replacement would leave a mixed CPU/GPU expression.
    if (meta.canExprTreeBeReplaced) meta.convertToGpu() else expr
  }
}

private def convertPartToGpuIfPossible(part: Partitioning, conf: RapidsConf): Partitioning = {
part match {
case _: GpuPartitioning => part
Expand Down Expand Up @@ -910,7 +895,6 @@ object GpuOverrides extends Logging {
private[this] def extractOrigParam(expr: BaseExprMeta[_]): BaseExprMeta[_] =
expr.wrapped match {
case lit: Literal if lit.dataType.isInstanceOf[DecimalType] =>
val dt = lit.dataType.asInstanceOf[DecimalType]
// Let's figure out if we can make the Literal value smaller
val (newType, value) = lit.value match {
case null =>
Expand All @@ -930,8 +914,10 @@ object GpuOverrides extends Logging {
throw new IllegalArgumentException(s"Unexpected decimal literal value $other")
}
expr.asInstanceOf[LiteralExprMeta].withNewLiteral(Literal(value, newType))
// We avoid unapply for Cast because it changes between versions of Spark
case PromotePrecision(c: CastBase) if c.dataType.isInstanceOf[DecimalType] =>
// Avoid unapply for PromotePrecision and Cast because it changes between Spark versions
case p: PromotePrecision if p.child.isInstanceOf[CastBase] &&
p.child.dataType.isInstanceOf[DecimalType] =>
val c = p.child.asInstanceOf[CastBase]
val to = c.dataType.asInstanceOf[DecimalType]
val fromType = DecimalUtil.optionallyAsDecimalType(c.child.dataType)
fromType match {
Expand Down

0 comments on commit d6e5eb7

Please sign in to comment.