Support fine grained timezone checker instead of type based [databricks] #9719

Merged 50 commits on Dec 1, 2023
Commits
801752c  Support current_date with timezone (winningsix, Nov 15, 2023)
36ddb24  Remove type based timezone checker (winningsix, Nov 17, 2023)
28ba19e  Add timezone checker for expressions (winningsix, Nov 17, 2023)
3d5d297  Add check for cast (winningsix, Nov 17, 2023)
b3f85bd  code style fix (winningsix, Nov 17, 2023)
c7c60d9  Fix premerge fails (winningsix, Nov 20, 2023)
90d975e  Fix (winningsix, Nov 20, 2023)
b88a13b  Refine comments (winningsix, Nov 21, 2023)
3cda255  Refine comments (winningsix, Nov 21, 2023)
42c7888  Refactor (winningsix, Nov 21, 2023)
1da1291  Refine comments (winningsix, Nov 22, 2023)
46dbe60  Typo (winningsix, Nov 22, 2023)
c954125  Re-enable failed test cases (winningsix, Nov 27, 2023)
eb32703  Fix inmatch cases (winningsix, Nov 27, 2023)
f0f6164  Revert before commit (winningsix, Nov 27, 2023)
e554f4e  Fix (winningsix, Nov 27, 2023)
d9fe752  Fix (winningsix, Nov 28, 2023)
78e5804  Comments (winningsix, Nov 28, 2023)
2ad72c8  Fix (winningsix, Nov 28, 2023)
7c734c0  Fix failed cases (winningsix, Nov 28, 2023)
e62ee80  Change xfail to allow_non_gpu (Nov 28, 2023)
568f8e2  Fix CSV scan (winningsix, Nov 29, 2023)
7a70ad0  Merge branch 'branch-23.12' into new_now (winningsix, Nov 29, 2023)
a743a7a  Fix (Nov 28, 2023)
4942496  Fix explain on CPU (winningsix, Nov 29, 2023)
83171ae  Fix json (winningsix, Nov 29, 2023)
e9b3b10  Fix json (winningsix, Nov 29, 2023)
a564891  Fix ORC scan (winningsix, Nov 29, 2023)
f497490  Fix ORC test (winningsix, Nov 29, 2023)
63d9e26  skip legacy mode rebase (winningsix, Nov 29, 2023)
1cbb694  Support check for AnsiCast (winningsix, Nov 29, 2023)
07b6863  Fix cases (Nov 29, 2023)
36cc096  Fix more cases (Nov 29, 2023)
3e70b52  Fix (winningsix, Nov 29, 2023)
51e5017  Refactor (winningsix, Nov 29, 2023)
07819fb  Fix more cases 3 (Nov 29, 2023)
d900607  Address comments (winningsix, Nov 29, 2023)
6e38bcb  Address comments (winningsix, Nov 29, 2023)
9fbe5b7  Merge branch 'branch-23.12' into new_now (winningsix, Nov 29, 2023)
dd07316  Fix for 341 (winningsix, Nov 30, 2023)
2e6578e  Fix 341 (winningsix, Nov 30, 2023)
33187b0  Minor fix (winningsix, Nov 30, 2023)
6dff012  Enable golden configuration (winningsix, Nov 30, 2023)
63d4394  Merge branch 'branch-24.02' into new_now (winningsix, Nov 30, 2023)
4d29350  Fix UTC cases (winningsix, Nov 30, 2023)
950ac3c  Address comments (winningsix, Nov 30, 2023)
eb0c85e  Address comments (winningsix, Dec 1, 2023)
78c026e  Merge branch 'branch-24.02' into new_now (winningsix, Dec 1, 2023)
0267c81  Fix a merge issue (winningsix, Dec 1, 2023)
216daf3  Minor fix (winningsix, Dec 1, 2023)
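
The two commits that name this feature ("Remove type based timezone checker", "Add timezone checker for expressions") swap a blanket, type-based rule for a per-expression one: instead of rejecting every timestamp-typed expression whenever the session time zone is not UTC, each expression now declares whether its semantics actually depend on that zone. That is why the diffs below can delete so many non-UTC xfail marks. The sketch below illustrates the distinction; it is illustrative Python only, every name in it is hypothetical, and the real checker lives in the plugin's Scala sources.

from dataclasses import dataclass

@dataclass
class Expr:
    name: str
    has_timestamp_type: bool  # a TimestampType appears among inputs or output
    timezone_sensitive: bool  # the result actually depends on the session zone

def type_based_ok(expr: Expr, session_tz: str) -> bool:
    # Old rule: any timestamp-typed expression falls back to the CPU off UTC,
    # even ones (like array indexing) that never consult the zone.
    return session_tz == 'UTC' or not expr.has_timestamp_type

def fine_grained_ok(expr: Expr, session_tz: str) -> bool:
    # New rule: only expressions that genuinely read the zone fall back.
    return session_tz == 'UTC' or not expr.timezone_sensitive

# GetArrayItem can carry timestamp elements but never reads the session zone.
get_array_item = Expr('GetArrayItem', has_timestamp_type=True, timezone_sensitive=False)
assert not type_based_ok(get_array_item, 'Asia/Shanghai')  # old checker: fallback
assert fine_grained_ok(get_array_item, 'Asia/Shanghai')    # new checker: stays on GPU
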
integration_tests/src/main/python/aqe_test.py (6 changes: 4 additions & 2 deletions)

@@ -21,6 +21,9 @@
from marks import ignore_order, allow_non_gpu
from spark_session import with_cpu_session, is_databricks113_or_later

+# allow non gpu when time zone is non-UTC because of https://github.com/NVIDIA/spark-rapids/issues/9653
+not_utc_aqe_allow=['ShuffleExchangeExec', 'HashAggregateExec'] if is_not_utc() else []
+
_adaptive_conf = { "spark.sql.adaptive.enabled": "true" }

def create_skew_df(spark, length):
Expand Down Expand Up @@ -194,9 +197,8 @@ def do_it(spark):
# broadcast join. The bug currently manifests in Databricks, but could
# theoretically show up in other Spark distributions
@ignore_order(local=True)
-@allow_non_gpu('BroadcastNestedLoopJoinExec', 'Cast', 'DateSub', *db_113_cpu_bnlj_join_allow)
+@allow_non_gpu('BroadcastNestedLoopJoinExec', 'Cast', 'DateSub', *db_113_cpu_bnlj_join_allow, *not_utc_aqe_allow)
@pytest.mark.parametrize('join', joins, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_aqe_join_reused_exchange_inequality_condition(spark_tmp_path, join):
data_path = spark_tmp_path + '/PARQUET_DATA'
def prep(spark):
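
Replacing xfail with allow_non_gpu (commit e62ee80 above) is a deliberate design choice: an xfail mark lets the whole test fail without signal in non-UTC runs, while allow_non_gpu keeps the CPU/GPU result comparison strict and only tolerates the named operators leaving the GPU. A condensed before/after of the decorator change above (the reason string is abbreviated here):

# Before: the entire test was expected to fail anywhere off UTC.
@pytest.mark.xfail(condition=is_not_utc(), reason='non-UTC, see issue #9653')

# After: results are still verified in every zone; only the listed execs may
# fall back, and the non-UTC additions vanish when the session zone is UTC.
@allow_non_gpu('BroadcastNestedLoopJoinExec', 'Cast', 'DateSub',
               *db_113_cpu_bnlj_join_allow, *not_utc_aqe_allow)
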
integration_tests/src/main/python/arithmetic_ops_test.py (3 changes: 0 additions & 3 deletions)

@@ -16,7 +16,6 @@
import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql
-from conftest import is_not_utc
from data_gen import *
from marks import ignore_order, incompat, approximate_float, allow_non_gpu, datagen_overrides
from pyspark.sql.types import *
@@ -984,7 +983,6 @@ def test_columnar_pow(data_gen):
lambda spark : binary_op_df(spark, data_gen).selectExpr('pow(a, b)'))

@pytest.mark.parametrize('data_gen', all_basic_gens + _arith_decimal_gens, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_least(data_gen):
num_cols = 20
s1 = with_cpu_session(
@@ -1001,7 +999,6 @@ def test_least(data_gen):
f.least(*command_args)))

@pytest.mark.parametrize('data_gen', all_basic_gens + _arith_decimal_gens, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_greatest(data_gen):
num_cols = 20
s1 = with_cpu_session(
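
The remaining files repeat one mechanical edit: tests that merely move timestamp values around (least/greatest above, the array indexing and set operations, and the AST comparisons below) lose their blanket non-UTC xfail mark entirely, with no replacement, because the expressions they exercise are not time-zone sensitive and the fine-grained checker now keeps them on the GPU. Reconstructed before/after for test_least, condensed from the hunk above:

# Before
@pytest.mark.parametrize('data_gen', all_basic_gens + _arith_decimal_gens, ids=idfn)
@pytest.mark.xfail(condition=is_not_utc(), reason='non-UTC, see issue #9653')
def test_least(data_gen): ...

# After: no time-zone mark at all; the test must now pass in every zone.
@pytest.mark.parametrize('data_gen', all_basic_gens + _arith_decimal_gens, ids=idfn)
def test_least(data_gen): ...
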
integration_tests/src/main/python/array_test.py (25 changes: 1 addition & 24 deletions)

@@ -16,7 +16,7 @@

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_gpu_fallback_collect
from data_gen import *
-from conftest import is_databricks_runtime, is_not_utc
+from conftest import is_databricks_runtime
from marks import incompat
from spark_session import is_before_spark_313, is_before_spark_330, is_databricks113_or_later, is_spark_330_or_later, is_databricks104_or_later, is_spark_33X, is_spark_340_or_later, is_spark_330, is_spark_330cdh
from pyspark.sql.types import *
@@ -103,13 +103,11 @@

@pytest.mark.parametrize('data_gen', array_item_test_gens, ids=idfn)
@pytest.mark.parametrize('index_gen', array_index_gens, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_item(data_gen, index_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: two_col_df(spark, data_gen, index_gen).selectExpr('a[b]'))

@pytest.mark.parametrize('data_gen', array_item_test_gens, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_item_lit_ordinal(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
@@ -147,7 +145,6 @@ def test_array_item_with_strict_index(strict_index_enabled, index):

# No need to test this for multiple data types for array. Only one is enough, but with two kinds of invalid index.
@pytest.mark.parametrize('index', [-2, 100, array_neg_index_gen, array_out_index_gen], ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_item_ansi_fail_invalid_index(index):
message = "SparkArrayIndexOutOfBoundsException" if (is_databricks104_or_later() or is_spark_330_or_later()) else "java.lang.ArrayIndexOutOfBoundsException"
if isinstance(index, int):
@@ -174,7 +171,6 @@ def test_array_item_ansi_not_fail_all_null_data():
decimal_gen_32bit, decimal_gen_64bit, decimal_gen_128bit, binary_gen,
StructGen([['child0', StructGen([['child01', IntegerGen()]])], ['child1', string_gen], ['child2', float_gen]], nullable=False),
StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]], nullable=False)], ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_make_array(data_gen):
(s1, s2) = with_cpu_session(
lambda spark: gen_scalars_for_sql(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen)))
@@ -187,7 +183,6 @@


@pytest.mark.parametrize('data_gen', single_level_array_gens, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_orderby_array_unique(data_gen):
assert_gpu_and_cpu_are_equal_sql(
lambda spark : append_unique_int_col_to_df(spark, unary_op_df(spark, data_gen)),
@@ -217,7 +212,6 @@ def test_orderby_array_of_structs(data_gen):
@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen,
float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen], ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_contains(data_gen):
arr_gen = ArrayGen(data_gen)
literal = with_cpu_session(lambda spark: gen_scalar(data_gen, force_no_nulls=True))
@@ -245,7 +239,6 @@ def test_array_contains_for_nans(data_gen):


@pytest.mark.parametrize('data_gen', array_item_test_gens, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_element_at(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: two_col_df(spark, data_gen, array_no_zero_index_gen).selectExpr(
@@ -310,7 +303,6 @@ def test_array_element_at_zero_index_fail(index, ansi_enabled):


@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_transform(data_gen):
def do_it(spark):
columns = ['a', 'b',
@@ -345,7 +337,6 @@ def do_it(spark):
string_gen, boolean_gen, date_gen, timestamp_gen, null_gen] + decimal_gens

@pytest.mark.parametrize('data_gen', array_min_max_gens, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_min_max(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, ArrayGen(data_gen)).selectExpr(
@@ -370,7 +361,6 @@ def test_array_concat_decimal(data_gen):
'concat(a, a)')))

@pytest.mark.parametrize('data_gen', orderable_gens + nested_gens_sample, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_repeat_with_count_column(data_gen):
cnt_gen = IntegerGen(min_val=-5, max_val=5, special_cases=[])
cnt_not_null_gen = IntegerGen(min_val=-5, max_val=5, special_cases=[], nullable=False)
@@ -384,7 +374,6 @@ def test_array_repeat_with_count_column(data_gen):


@pytest.mark.parametrize('data_gen', orderable_gens + nested_gens_sample, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_repeat_with_count_scalar(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
@@ -414,7 +403,6 @@ def test_sql_array_scalars(query):


@pytest.mark.parametrize('data_gen', all_basic_gens + nested_gens_sample, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_get_array_struct_fields(data_gen):
array_struct_gen = ArrayGen(
StructGen([['child0', data_gen], ['child1', int_gen]]),
@@ -453,7 +441,6 @@ def do_it(spark):


@pytest.mark.parametrize('data_gen', array_zips_gen, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_arrays_zip(data_gen):
gen = StructGen(
[('a', data_gen), ('b', data_gen), ('c', data_gen), ('d', data_gen)], nullable=False)
@@ -486,7 +473,6 @@ def q1(spark):

@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens + decimal_gens, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
@pytest.mark.skipif(is_before_spark_313() or is_spark_330() or is_spark_330cdh(), reason="NaN equality is only handled in Spark 3.1.3+ and SPARK-39976 issue with null and ArrayIntersect in Spark 3.3.0")
def test_array_intersect(data_gen):
gen = StructGen(
@@ -528,7 +514,6 @@ def test_array_intersect_spark330(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens_no_nans + decimal_gens, ids=idfn)
@pytest.mark.skipif(not is_before_spark_313(), reason="NaN equality is only handled in Spark 3.1.3+")
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_intersect_before_spark313(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
@@ -549,7 +534,6 @@ def test_array_intersect_before_spark313(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens + decimal_gens, ids=idfn)
@pytest.mark.skipif(is_before_spark_313(), reason="NaN equality is only handled in Spark 3.1.3+")
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_union(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
@@ -570,7 +554,6 @@ def test_array_union(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens_no_nans + decimal_gens, ids=idfn)
@pytest.mark.skipif(not is_before_spark_313(), reason="NaN equality is only handled in Spark 3.1.3+")
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_union_before_spark313(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
@@ -591,7 +574,6 @@ def test_array_union_before_spark313(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens + decimal_gens, ids=idfn)
@pytest.mark.skipif(is_before_spark_313(), reason="NaN equality is only handled in Spark 3.1.3+")
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_except(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
@@ -612,7 +594,6 @@ def test_array_except(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens_no_nans + decimal_gens, ids=idfn)
@pytest.mark.skipif(not is_before_spark_313(), reason="NaN equality is only handled in Spark 3.1.3+")
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_except_before_spark313(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
@@ -633,7 +614,6 @@ def test_array_except_before_spark313(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens + decimal_gens, ids=idfn)
@pytest.mark.skipif(is_before_spark_313(), reason="NaN equality is only handled in Spark 3.1.3+")
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_arrays_overlap(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
@@ -655,7 +635,6 @@ def test_arrays_overlap(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', no_neg_zero_all_basic_gens_no_nans + decimal_gens, ids=idfn)
@pytest.mark.skipif(not is_before_spark_313(), reason="NaN equality is only handled in Spark 3.1.3+")
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_arrays_overlap_before_spark313(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
@@ -693,7 +672,6 @@ def test_array_remove_scalar(data_gen):
FloatGen(special_cases=_non_neg_zero_float_special_cases + [-0.0]),
DoubleGen(special_cases=_non_neg_zero_double_special_cases + [-0.0]),
StringGen(pattern='[0-9]{1,5}'), boolean_gen, date_gen, timestamp_gen] + decimal_gens, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_array_remove(data_gen):
gen = StructGen(
[('a', ArrayGen(data_gen, nullable=True)),
@@ -708,7 +686,6 @@ def test_array_remove(data_gen):


@pytest.mark.parametrize('data_gen', [ArrayGen(sub_gen) for sub_gen in array_gens_sample], ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_flatten_array(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr('flatten(a)')
integration_tests/src/main/python/ast_test.py (9 changes: 0 additions & 9 deletions)

@@ -15,7 +15,6 @@
import pytest

from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture
-from conftest import is_not_utc
from data_gen import *
from marks import approximate_float, datagen_overrides
from spark_session import with_cpu_session, is_before_spark_330
@@ -71,7 +70,6 @@ def assert_binary_ast(data_descr, func, conf={}):
assert_gpu_ast(is_supported, lambda spark: func(binary_op_df(spark, data_gen)), conf=conf)

@pytest.mark.parametrize('data_gen', [boolean_gen, byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, timestamp_gen, date_gen], ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_literal(spark_tmp_path, data_gen):
# Write data to Parquet so Spark generates a plan using just the count of the data.
data_path = spark_tmp_path + '/AST_TEST_DATA'
@@ -81,7 +79,6 @@ def test_literal(spark_tmp_path, data_gen):
func=lambda spark: spark.read.parquet(data_path).select(scalar))

@pytest.mark.parametrize('data_gen', [boolean_gen, byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, timestamp_gen, date_gen], ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_null_literal(spark_tmp_path, data_gen):
# Write data to Parquet so Spark generates a plan using just the count of the data.
data_path = spark_tmp_path + '/AST_TEST_DATA'
@@ -235,7 +232,6 @@ def test_expm1(data_descr):
assert_unary_ast(data_descr, lambda df: df.selectExpr('expm1(a)'))

@pytest.mark.parametrize('data_descr', ast_comparable_descrs, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_eq(data_descr):
(s1, s2) = with_cpu_session(lambda spark: gen_scalars(data_descr[0], 2))
assert_binary_ast(data_descr,
@@ -245,7 +241,6 @@ def test_eq(data_descr):
f.col('a') == f.col('b')))

@pytest.mark.parametrize('data_descr', ast_comparable_descrs, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_ne(data_descr):
(s1, s2) = with_cpu_session(lambda spark: gen_scalars(data_descr[0], 2))
assert_binary_ast(data_descr,
@@ -255,7 +250,6 @@ def test_ne(data_descr):
f.col('a') != f.col('b')))

@pytest.mark.parametrize('data_descr', ast_comparable_descrs, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_lt(data_descr):
(s1, s2) = with_cpu_session(lambda spark: gen_scalars(data_descr[0], 2))
assert_binary_ast(data_descr,
@@ -265,7 +259,6 @@ def test_lt(data_descr):
f.col('a') < f.col('b')))

@pytest.mark.parametrize('data_descr', ast_comparable_descrs, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_lte(data_descr):
(s1, s2) = with_cpu_session(lambda spark: gen_scalars(data_descr[0], 2))
assert_binary_ast(data_descr,
@@ -275,7 +268,6 @@ def test_lte(data_descr):
f.col('a') <= f.col('b')))

@pytest.mark.parametrize('data_descr', ast_comparable_descrs, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_gt(data_descr):
(s1, s2) = with_cpu_session(lambda spark: gen_scalars(data_descr[0], 2))
assert_binary_ast(data_descr,
@@ -285,7 +277,6 @@ def test_gt(data_descr):
f.col('a') > f.col('b')))

@pytest.mark.parametrize('data_descr', ast_comparable_descrs, ids=idfn)
-@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_gte(data_descr):
(s1, s2) = with_cpu_session(lambda spark: gen_scalars(data_descr[0], 2))
assert_binary_ast(data_descr,