Prep miscellaneous integ tests for Spark 4
Fixes NVIDIA#11020. (grouping_sets_test.py)
Fixes NVIDIA#11023. (dpp_test.py)
Fixes NVIDIA#11025. (date_time_test.py)
Fixes NVIDIA#11026. (map_test.py)

This commit prepares miscellaneous integration tests to be run on Spark
4.

Certain integration tests fail on Spark 4 because ANSI mode is enabled by
default.  This commit disables ANSI mode on the failing tests, or introduces
other fixes, so that the tests pass correctly.

Signed-off-by: MithunR <mithunr@nvidia.com>
mythrocks committed Jun 25, 2024
1 parent 6455396 commit 9316347
Showing 4 changed files with 26 additions and 4 deletions.
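
For context on the pattern applied throughout this diff: `disable_ansi_mode` is a pytest marker provided by the suite's marks.py, and the harness presumably uses it to pin `spark.sql.ansi.enabled` to false for the decorated test, since Spark 4 enables ANSI mode by default. Below is a minimal, self-contained sketch of how such a marker can be wired up; the `conf_for_test` helper and the conf-injection logic are illustrative assumptions, not the repository's actual implementation.

```python
# Hedged sketch (assumed wiring): a pytest marker that forces ANSI mode off
# for a test, so tests that are not ANSI-aware still pass on Spark 4.
import pytest

# Marker handle, analogous to `disable_ansi_mode` imported from marks.py.
disable_ansi_mode = pytest.mark.disable_ansi_mode


def conf_for_test(request, base_conf=None):
    """Build the Spark conf for a test, forcing ANSI off when the marker is set."""
    conf = dict(base_conf or {})
    if request.node.get_closest_marker("disable_ansi_mode"):
        # Spark 4 turns spark.sql.ansi.enabled on by default; disable it for
        # tests whose ANSI behavior is exercised by separate dedicated tests.
        conf["spark.sql.ansi.enabled"] = "false"
    return conf


@disable_ansi_mode  # ANSI failure modes are covered elsewhere.
def test_example(request):
    conf = conf_for_test(request, {"spark.rapids.sql.enabled": "true"})
    assert conf["spark.sql.ansi.enabled"] == "false"
```

In the real suite the conf injection would likely live in shared fixtures or assertion helpers, so individual tests only need the decorator, as seen in the hunks below.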
integration_tests/src/main/python/date_time_test.py (8 additions, 1 deletion)
@@ -17,7 +17,7 @@
from conftest import is_utc, is_supported_time_zone, get_test_tz
from data_gen import *
from datetime import date, datetime, timezone
from marks import ignore_order, incompat, allow_non_gpu, datagen_overrides, tz_sensitive_test
from marks import allow_non_gpu, datagen_overrides, disable_ansi_mode, ignore_order, incompat, tz_sensitive_test
from pyspark.sql.types import *
from spark_session import with_cpu_session, is_before_spark_330, is_before_spark_350
import pyspark.sql.functions as f
@@ -91,6 +91,8 @@ def fun(spark):

assert_gpu_and_cpu_are_equal_collect(fun)


@disable_ansi_mode # ANSI mode tested separately.
# Should specify `spark.sql.legacy.interval.enabled` to test `DateAddInterval` after Spark 3.2.0,
# refer to https://issues.apache.org/jira/browse/SPARK-34896
# [SPARK-34896][SQL] Return day-time interval from dates subtraction
@@ -437,6 +439,8 @@ def test_string_unix_timestamp_ansi_exception():
error_message="Exception",
conf=ansi_enabled_conf)


@disable_ansi_mode # ANSI mode is tested separately.
@tz_sensitive_test
@pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
@pytest.mark.parametrize('parser_policy', ["CORRECTED", "EXCEPTION"], ids=idfn)
@@ -561,6 +565,8 @@ def test_date_format_maybe_incompat(data_gen, date_format):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)), conf)


@disable_ansi_mode # ANSI mode tested separately.
# Reproduce conditions for https://github.com/NVIDIA/spark-rapids/issues/5670
# where we had a failure due to GpuCast canonicalization with timezone.
# In this case it was doing filter after project, the way I get that to happen is by adding in the
@@ -594,6 +600,7 @@ def test_unsupported_fallback_date_format(data_gen):
conf)


@disable_ansi_mode # Failure cases for ANSI mode are tested separately.
@allow_non_gpu('ProjectExec')
def test_unsupported_fallback_to_date():
date_gen = StringGen(pattern="2023-08-01")
integration_tests/src/main/python/dpp_test.py (10 additions, 2 deletions)
@@ -19,7 +19,7 @@
from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture, assert_gpu_and_cpu_are_equal_collect
from conftest import spark_tmp_table_factory
from data_gen import *
from marks import ignore_order, allow_non_gpu, datagen_overrides
from marks import ignore_order, allow_non_gpu, datagen_overrides, disable_ansi_mode
from spark_session import is_before_spark_320, with_cpu_session, is_before_spark_312, is_databricks_runtime, is_databricks113_or_later

# non-positive values here can produce a degenerative join, so here we ensure that most values are
@@ -167,7 +167,7 @@ def fn(spark):
'''
]


@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
# When BroadcastExchangeExec is available on filtering side, and it can be reused:
# DynamicPruningExpression(InSubqueryExec(value, GpuSubqueryBroadcastExec)))
@ignore_order
@@ -198,6 +198,7 @@ def test_dpp_reuse_broadcast_exchange(spark_tmp_table_factory, store_format, s_i
conf=dict(_exchange_reuse_conf + [('spark.sql.adaptive.enabled', aqe_enabled)]))


@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
# The SubqueryBroadcast can work on GPU even if the scan who holds it fallbacks into CPU.
@ignore_order
@pytest.mark.allow_non_gpu('FileSourceScanExec')
@@ -215,6 +216,7 @@ def test_dpp_reuse_broadcast_exchange_cpu_scan(spark_tmp_table_factory):
('spark.rapids.sql.format.parquet.read.enabled', 'false')]))


@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
# When BroadcastExchange is not available and non-broadcast DPPs are forbidden, Spark will bypass it:
# DynamicPruningExpression(Literal.TrueLiteral)
@ignore_order
@@ -238,6 +240,7 @@ def test_dpp_bypass(spark_tmp_table_factory, store_format, s_index, aqe_enabled)
conf=dict(_bypass_conf + [('spark.sql.adaptive.enabled', aqe_enabled)]))


@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
# When BroadcastExchange is not available, but it is still worthwhile to run DPP,
# then Spark will plan an extra Aggregate to collect filtering values:
# DynamicPruningExpression(InSubqueryExec(value, SubqueryExec(Aggregate(...))))
@@ -261,6 +264,7 @@ def test_dpp_via_aggregate_subquery(spark_tmp_table_factory, store_format, s_ind
conf=dict(_no_exchange_reuse_conf + [('spark.sql.adaptive.enabled', aqe_enabled)]))


@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
# When BroadcastExchange is not available, Spark will skip DPP if there is no potential benefit
@ignore_order
@pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
@@ -321,6 +325,8 @@ def create_dim_table_for_like(spark):
exist_classes,
conf=dict(_exchange_reuse_conf + [('spark.sql.adaptive.enabled', aqe_enabled)]))


@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
# Test handling DPP expressions from a HashedRelation that rearranges columns
@pytest.mark.parametrize('aqe_enabled', [
'false',
@@ -351,6 +357,8 @@ def setup_tables(spark):
("spark.rapids.sql.castStringToTimestamp.enabled", "true"),
("spark.rapids.sql.hasExtendedYearValues", "false")]))


@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
# Test handling DPP subquery that could broadcast EmptyRelation rather than a GPU serialized batch
@pytest.mark.parametrize('aqe_enabled', [
'false',
integration_tests/src/main/python/grouping_sets_test.py (2 additions, 0 deletions)
@@ -41,6 +41,8 @@
'GROUP BY a, GROUPING SETS((a, b), (a), (), (a, b), (a), (b), ())',
]


@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
# test nested syntax of grouping set, rollup and cube
@ignore_order
@pytest.mark.parametrize('data_gen', [_grouping_set_gen], ids=idfn)
integration_tests/src/main/python/map_test.py (6 additions, 1 deletion)
@@ -18,7 +18,7 @@
from conftest import is_not_utc
from data_gen import *
from conftest import is_databricks_runtime
from marks import allow_non_gpu, ignore_order, datagen_overrides
from marks import allow_non_gpu, datagen_overrides, disable_ansi_mode, ignore_order
from spark_session import *
from pyspark.sql.functions import create_map, col, lit, row_number
from pyspark.sql.types import *
@@ -138,6 +138,7 @@ def test_get_map_value_string_keys(data_gen):
for key in numeric_key_gens for value in get_map_value_gens()]


@disable_ansi_mode # ANSI mode failures are tested separately.
@pytest.mark.parametrize('data_gen', numeric_key_map_gens, ids=idfn)
def test_get_map_value_numeric_keys(data_gen):
key_gen = data_gen._key_gen
@@ -151,6 +152,7 @@ def test_get_map_value_numeric_keys(data_gen):
'a[999]'))


@disable_ansi_mode # ANSI mode failures are tested separately.
@pytest.mark.parametrize('data_gen', supported_key_map_gens, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_get_map_value_supported_keys(data_gen):
@@ -174,6 +176,7 @@ def test_get_map_value_fallback_keys(data_gen):
cpu_fallback_class_name="GetMapValue")


@disable_ansi_mode # ANSI mode failures are tested separately.
@pytest.mark.parametrize('key_gen', numeric_key_gens, ids=idfn)
def test_basic_scalar_map_get_map_value(key_gen):
def query_map_scalar(spark):
@@ -639,6 +642,8 @@ def test_map_element_at_ansi_null(data_gen):
'element_at(a, "NOT_FOUND")'),
conf=ansi_enabled_conf)


@disable_ansi_mode # ANSI mode failures are tested separately.
@pytest.mark.parametrize('data_gen', map_gens_sample, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_transform_values(data_gen):