Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/branch-24.02' into rebalance-p…
Browse files Browse the repository at this point in the history
…remerge-time
  • Loading branch information
YanxuanLiu committed Dec 4, 2023
2 parents 9bef297 + 557680b commit 800b9fa
Show file tree
Hide file tree
Showing 74 changed files with 788 additions and 617 deletions.
6 changes: 4 additions & 2 deletions integration_tests/src/main/python/aqe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from marks import ignore_order, allow_non_gpu
from spark_session import with_cpu_session, is_databricks113_or_later

# allow non gpu when time zone is non-UTC because of https://github.com/NVIDIA/spark-rapids/issues/9653'
not_utc_aqe_allow=['ShuffleExchangeExec', 'HashAggregateExec'] if is_not_utc() else []

_adaptive_conf = { "spark.sql.adaptive.enabled": "true" }

def create_skew_df(spark, length):
Expand Down Expand Up @@ -194,9 +197,8 @@ def do_it(spark):
# broadcast join. The bug currently manifests in Databricks, but could
# theoretically show up in other Spark distributions
@ignore_order(local=True)
@allow_non_gpu('BroadcastNestedLoopJoinExec', 'Cast', 'DateSub', *db_113_cpu_bnlj_join_allow)
@allow_non_gpu('BroadcastNestedLoopJoinExec', 'Cast', 'DateSub', *db_113_cpu_bnlj_join_allow, *not_utc_aqe_allow)
@pytest.mark.parametrize('join', joins, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_aqe_join_reused_exchange_inequality_condition(spark_tmp_path, join):
data_path = spark_tmp_path + '/PARQUET_DATA'
def prep(spark):
Expand Down
3 changes: 0 additions & 3 deletions integration_tests/src/main/python/arithmetic_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql
from conftest import is_not_utc
from data_gen import *
from marks import ignore_order, incompat, approximate_float, allow_non_gpu, datagen_overrides
from pyspark.sql.types import *
Expand Down Expand Up @@ -984,7 +983,6 @@ def test_columnar_pow(data_gen):
lambda spark : binary_op_df(spark, data_gen).selectExpr('pow(a, b)'))

@pytest.mark.parametrize('data_gen', all_basic_gens + _arith_decimal_gens, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_least(data_gen):
num_cols = 20
s1 = with_cpu_session(
Expand All @@ -1001,7 +999,6 @@ def test_least(data_gen):
f.least(*command_args)))

@pytest.mark.parametrize('data_gen', all_basic_gens + _arith_decimal_gens, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_greatest(data_gen):
num_cols = 20
s1 = with_cpu_session(
Expand Down
25 changes: 1 addition & 24 deletions integration_tests/src/main/python/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_gpu_fallback_collect
from data_gen import *
from conftest import is_databricks_runtime, is_not_utc
from conftest import is_databricks_runtime
from marks import incompat
from spark_session import is_before_spark_313, is_before_spark_330, is_databricks113_or_later, is_spark_330_or_later, is_databricks104_or_later, is_spark_33X, is_spark_340_or_later, is_spark_330, is_spark_330cdh
from pyspark.sql.types import *
Expand Down Expand Up @@ -103,13 +103,11 @@

@pytest.mark.parametrize('data_gen', array_item_test_gens, ids=idfn)
@pytest.mark.parametrize('index_gen', array_index_gens, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_item(data_gen, index_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: two_col_df(spark, data_gen, index_gen).selectExpr('a[b]'))

@pytest.mark.parametrize('data_gen', array_item_test_gens, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_item_lit_ordinal(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
Expand Down Expand Up @@ -147,7 +145,6 @@ def test_array_item_with_strict_index(strict_index_enabled, index):

# No need to test this for multiple data types for array. Only one is enough, but with two kinds of invalid index.
@pytest.mark.parametrize('index', [-2, 100, array_neg_index_gen, array_out_index_gen], ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_item_ansi_fail_invalid_index(index):
message = "SparkArrayIndexOutOfBoundsException" if (is_databricks104_or_later() or is_spark_330_or_later()) else "java.lang.ArrayIndexOutOfBoundsException"
if isinstance(index, int):
Expand All @@ -174,7 +171,6 @@ def test_array_item_ansi_not_fail_all_null_data():
decimal_gen_32bit, decimal_gen_64bit, decimal_gen_128bit, binary_gen,
StructGen([['child0', StructGen([['child01', IntegerGen()]])], ['child1', string_gen], ['child2', float_gen]], nullable=False),
StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]], nullable=False)], ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_make_array(data_gen):
(s1, s2) = with_cpu_session(
lambda spark: gen_scalars_for_sql(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen)))
Expand All @@ -187,7 +183,6 @@ def test_make_array(data_gen):


@pytest.mark.parametrize('data_gen', single_level_array_gens, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_orderby_array_unique(data_gen):
assert_gpu_and_cpu_are_equal_sql(
lambda spark : append_unique_int_col_to_df(spark, unary_op_df(spark, data_gen)),
Expand Down Expand Up @@ -217,7 +212,6 @@ def test_orderby_array_of_structs(data_gen):
@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen,
float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen], ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_contains(data_gen):
arr_gen = ArrayGen(data_gen)
literal = with_cpu_session(lambda spark: gen_scalar(data_gen, force_no_nulls=True))
Expand Down Expand Up @@ -245,7 +239,6 @@ def test_array_contains_for_nans(data_gen):


@pytest.mark.parametrize('data_gen', array_item_test_gens, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_element_at(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: two_col_df(spark, data_gen, array_no_zero_index_gen).selectExpr(
Expand Down Expand Up @@ -310,7 +303,6 @@ def test_array_element_at_zero_index_fail(index, ansi_enabled):


@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_transform(data_gen):
def do_it(spark):
columns = ['a', 'b',
Expand Down Expand Up @@ -345,7 +337,6 @@ def do_it(spark):
string_gen, boolean_gen, date_gen, timestamp_gen, null_gen] + decimal_gens

@pytest.mark.parametrize('data_gen', array_min_max_gens, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_min_max(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, ArrayGen(data_gen)).selectExpr(
Expand All @@ -370,7 +361,6 @@ def test_array_concat_decimal(data_gen):
'concat(a, a)')))

@pytest.mark.parametrize('data_gen', orderable_gens + nested_gens_sample, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_repeat_with_count_column(data_gen):
cnt_gen = IntegerGen(min_val=-5, max_val=5, special_cases=[])
cnt_not_null_gen = IntegerGen(min_val=-5, max_val=5, special_cases=[], nullable=False)
Expand All @@ -384,7 +374,6 @@ def test_array_repeat_with_count_column(data_gen):


@pytest.mark.parametrize('data_gen', orderable_gens + nested_gens_sample, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_repeat_with_count_scalar(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
Expand Down Expand Up @@ -414,7 +403,6 @@ def test_sql_array_scalars(query):


@pytest.mark.parametrize('data_gen', all_basic_gens + nested_gens_sample, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_get_array_struct_fields(data_gen):
array_struct_gen = ArrayGen(
StructGen([['child0', data_gen], ['child1', int_gen]]),
Expand Down Expand Up @@ -453,7 +441,6 @@ def do_it(spark):


@pytest.mark.parametrize('data_gen', array_zips_gen, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_arrays_zip(data_gen):
gen = StructGen(
[('a', data_gen), ('b', data_gen), ('c', data_gen), ('d', data_gen)], nullable=False)
Expand Down Expand Up @@ -486,7 +473,6 @@ def q1(spark):

@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens + decimal_gens, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
@pytest.mark.skipif(is_before_spark_313() or is_spark_330() or is_spark_330cdh(), reason="NaN equality is only handled in Spark 3.1.3+ and SPARK-39976 issue with null and ArrayIntersect in Spark 3.3.0")
def test_array_intersect(data_gen):
gen = StructGen(
Expand Down Expand Up @@ -528,7 +514,6 @@ def test_array_intersect_spark330(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens_no_nans + decimal_gens, ids=idfn)
@pytest.mark.skipif(not is_before_spark_313(), reason="NaN equality is only handled in Spark 3.1.3+")
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_intersect_before_spark313(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
Expand All @@ -549,7 +534,6 @@ def test_array_intersect_before_spark313(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens + decimal_gens, ids=idfn)
@pytest.mark.skipif(is_before_spark_313(), reason="NaN equality is only handled in Spark 3.1.3+")
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_union(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
Expand All @@ -570,7 +554,6 @@ def test_array_union(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens_no_nans + decimal_gens, ids=idfn)
@pytest.mark.skipif(not is_before_spark_313(), reason="NaN equality is only handled in Spark 3.1.3+")
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_union_before_spark313(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
Expand All @@ -591,7 +574,6 @@ def test_array_union_before_spark313(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens + decimal_gens, ids=idfn)
@pytest.mark.skipif(is_before_spark_313(), reason="NaN equality is only handled in Spark 3.1.3+")
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_except(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
Expand All @@ -612,7 +594,6 @@ def test_array_except(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens_no_nans + decimal_gens, ids=idfn)
@pytest.mark.skipif(not is_before_spark_313(), reason="NaN equality is only handled in Spark 3.1.3+")
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_except_before_spark313(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
Expand All @@ -633,7 +614,6 @@ def test_array_except_before_spark313(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens + decimal_gens, ids=idfn)
@pytest.mark.skipif(is_before_spark_313(), reason="NaN equality is only handled in Spark 3.1.3+")
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_arrays_overlap(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
Expand All @@ -655,7 +635,6 @@ def test_arrays_overlap(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens_no_nans + decimal_gens, ids=idfn)
@pytest.mark.skipif(not is_before_spark_313(), reason="NaN equality is only handled in Spark 3.1.3+")
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_arrays_overlap_before_spark313(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
Expand Down Expand Up @@ -693,7 +672,6 @@ def test_array_remove_scalar(data_gen):
FloatGen(special_cases=_non_neg_zero_float_special_cases + [-0.0]),
DoubleGen(special_cases=_non_neg_zero_double_special_cases + [-0.0]),
StringGen(pattern='[0-9]{1,5}'), boolean_gen, date_gen, timestamp_gen] + decimal_gens, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_remove(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
Expand All @@ -708,7 +686,6 @@ def test_array_remove(data_gen):


@pytest.mark.parametrize('data_gen', [ArrayGen(sub_gen) for sub_gen in array_gens_sample], ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_flatten_array(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr('flatten(a)')
Expand Down
9 changes: 0 additions & 9 deletions integration_tests/src/main/python/ast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import pytest

from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture
from conftest import is_not_utc
from data_gen import *
from marks import approximate_float, datagen_overrides
from spark_session import with_cpu_session, is_before_spark_330
Expand Down Expand Up @@ -71,7 +70,6 @@ def assert_binary_ast(data_descr, func, conf={}):
assert_gpu_ast(is_supported, lambda spark: func(binary_op_df(spark, data_gen)), conf=conf)

@pytest.mark.parametrize('data_gen', [boolean_gen, byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, timestamp_gen, date_gen], ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_literal(spark_tmp_path, data_gen):
# Write data to Parquet so Spark generates a plan using just the count of the data.
data_path = spark_tmp_path + '/AST_TEST_DATA'
Expand All @@ -81,7 +79,6 @@ def test_literal(spark_tmp_path, data_gen):
func=lambda spark: spark.read.parquet(data_path).select(scalar))

@pytest.mark.parametrize('data_gen', [boolean_gen, byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, timestamp_gen, date_gen], ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_null_literal(spark_tmp_path, data_gen):
# Write data to Parquet so Spark generates a plan using just the count of the data.
data_path = spark_tmp_path + '/AST_TEST_DATA'
Expand Down Expand Up @@ -235,7 +232,6 @@ def test_expm1(data_descr):
assert_unary_ast(data_descr, lambda df: df.selectExpr('expm1(a)'))

@pytest.mark.parametrize('data_descr', ast_comparable_descrs, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_eq(data_descr):
(s1, s2) = with_cpu_session(lambda spark: gen_scalars(data_descr[0], 2))
assert_binary_ast(data_descr,
Expand All @@ -245,7 +241,6 @@ def test_eq(data_descr):
f.col('a') == f.col('b')))

@pytest.mark.parametrize('data_descr', ast_comparable_descrs, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_ne(data_descr):
(s1, s2) = with_cpu_session(lambda spark: gen_scalars(data_descr[0], 2))
assert_binary_ast(data_descr,
Expand All @@ -255,7 +250,6 @@ def test_ne(data_descr):
f.col('a') != f.col('b')))

@pytest.mark.parametrize('data_descr', ast_comparable_descrs, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_lt(data_descr):
(s1, s2) = with_cpu_session(lambda spark: gen_scalars(data_descr[0], 2))
assert_binary_ast(data_descr,
Expand All @@ -265,7 +259,6 @@ def test_lt(data_descr):
f.col('a') < f.col('b')))

@pytest.mark.parametrize('data_descr', ast_comparable_descrs, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_lte(data_descr):
(s1, s2) = with_cpu_session(lambda spark: gen_scalars(data_descr[0], 2))
assert_binary_ast(data_descr,
Expand All @@ -275,7 +268,6 @@ def test_lte(data_descr):
f.col('a') <= f.col('b')))

@pytest.mark.parametrize('data_descr', ast_comparable_descrs, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_gt(data_descr):
(s1, s2) = with_cpu_session(lambda spark: gen_scalars(data_descr[0], 2))
assert_binary_ast(data_descr,
Expand All @@ -285,7 +277,6 @@ def test_gt(data_descr):
f.col('a') > f.col('b')))

@pytest.mark.parametrize('data_descr', ast_comparable_descrs, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_gte(data_descr):
(s1, s2) = with_cpu_session(lambda spark: gen_scalars(data_descr[0], 2))
assert_binary_ast(data_descr,
Expand Down
Loading

0 comments on commit 800b9fa

Please sign in to comment.