Fixes a bug where datagen seed overrides were sticky and adds datagen_seed_override_disabled #10109

Merged
12 changes: 12 additions & 0 deletions integration_tests/README.md
@@ -343,6 +343,18 @@ integration tests. For example:
$ DATAGEN_SEED=1702166057 SPARK_HOME=~/spark-3.4.0-bin-hadoop3 integration_tests/run_pyspark_from_build.sh
```

Tests can override the datagen seed using the following test marker:

```
@datagen_overrides(seed=<new seed here>, [condition=True|False], [permanent=True|False])
```

This marker has the following arguments (a usage sketch follows the list):
- `seed`: a hard-coded datagen seed to use.
- `condition`: gates when the override applies; typically used to restrict the override to specific shims.
- `permanent`: if True, the test ignores `DATAGEN_SEED` and always uses the override seed. If False, or if absent, a user-provided `DATAGEN_SEED` value always wins.
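For example, an unstable test could pin its seed only on older Spark versions. The sketch below is illustrative only: the test name, body, and `is_before_spark_340()` condition are hypothetical, and the marker is assumed to be available through the integration test `marks` module as in existing tests:

```
@datagen_overrides(seed=0, condition=is_before_spark_340(), permanent=False)
def test_example_with_pinned_seed():
    # The test body uses datagen as usual; generators are seeded with 0 when the
    # condition holds, unless DATAGEN_SEED is set (since permanent=False).
    ...
```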

### Running with non-UTC time zone
For newly added cases, we should check that they work in a non-UTC time zone, or the non-UTC nightly CIs will fail.
The non-UTC nightly CIs verify all cases with a non-UTC time zone.
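As a hedged sketch (assuming the launcher honors the standard `TZ` environment variable; the specific zone is only an example), a local non-UTC run could look like:

```
$ TZ=Asia/Shanghai integration_tests/run_pyspark_from_build.sh
```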
71 changes: 52 additions & 19 deletions integration_tests/src/main/python/conftest.py
@@ -151,8 +151,11 @@ def get_inject_oom_conf():
# For datagen: we expect a seed to be provided by the environment, or default to 0.
# Note that tests can override their seed when calling into datagen by setting seed= in their tests.
_test_datagen_random_seed = int(os.getenv("SPARK_RAPIDS_TEST_DATAGEN_SEED", 0))
print(f"Starting with datagen test seed: {_test_datagen_random_seed}. "
"Set env variable SPARK_RAPIDS_TEST_DATAGEN_SEED to override.")
_test_datagen_random_seed_user_provided = os.getenv("DATAGEN_SEED") is not None
provided_by_msg = "Provided by user with DATAGEN_SEED" if _test_datagen_random_seed_user_provided else "Automatically set"
_test_datagen_random_seed_init = _test_datagen_random_seed
print(f"Starting with datagen test seed: {_test_datagen_random_seed_init} ({provided_by_msg}). "
"Set env variable DATAGEN_SEED to override.")

def get_datagen_seed():
return _test_datagen_random_seed
@@ -178,16 +181,7 @@ def pytest_runtest_setup(item):
global _test_datagen_random_seed
_inject_oom = item.get_closest_marker('inject_oom')
datagen_overrides = item.get_closest_marker('datagen_overrides')
if datagen_overrides:
try:
seed = datagen_overrides.kwargs["seed"]
except KeyError:
raise Exception("datagen_overrides requires an override seed value")

override_seed = datagen_overrides.kwargs.get('condition', True)
if override_seed:
_test_datagen_random_seed = seed

_test_datagen_random_seed, _ = get_effective_seed(item, datagen_overrides)
order = item.get_closest_marker('ignore_order')
if order:
if order.kwargs.get('local', False):
@@ -307,6 +301,37 @@ def pytest_configure(config):
print(f"Starting with OOM injection seed: {oom_random_injection_seed}. "
"Set env variable SPARK_RAPIDS_TEST_INJECT_OOM_SEED to override.")

# Returns a tuple (seed, permanent) with the seed that test `item` should use given a
# possibly defined `datagen_overrides`, and if the seed choice is due to an override,
# whether that override is marked as `permanent`
def get_effective_seed(item, datagen_overrides):
if datagen_overrides:
is_permanent = False
# if the override is marked as permanent it will always override its seed
# else, if the user provides a seed via DATAGEN_SEED, we will override.
try:
is_permanent = datagen_overrides.kwargs["permanent"]
except KeyError:
pass

override_condition = datagen_overrides.kwargs.get('condition', True)
do_override = (
# if the override condition is satisfied, we consider it
override_condition and (
# if the override is permanent, we always override
# if it is not permanent, we consider it only if the user didn't
# set DATAGEN_SEED
is_permanent or not _test_datagen_random_seed_user_provided))

if do_override:
try:
seed = datagen_overrides.kwargs["seed"]
except KeyError:
raise Exception("datagen_overrides requires an override seed value")
return (seed, is_permanent)

return (_test_datagen_random_seed_init, False)

def pytest_collection_modifyitems(config, items):
r = random.Random(oom_random_injection_seed)
for item in items:
@@ -318,14 +343,22 @@ def pytest_collection_modifyitems(config, items):
injection_conf = injection_mode_and_conf[1] if len(injection_mode_and_conf) == 2 else None
inject_choice = False
datagen_overrides = item.get_closest_marker('datagen_overrides')
test_datagen_random_seed_choice, is_permanent = get_effective_seed(item, datagen_overrides)
qualifier = ""
if datagen_overrides:
test_datagen_random_seed_choice = datagen_overrides.kwargs.get('seed', _test_datagen_random_seed)
if test_datagen_random_seed_choice != _test_datagen_random_seed:
extras.append('DATAGEN_SEED_OVERRIDE=%s' % str(test_datagen_random_seed_choice))
else:
extras.append('DATAGEN_SEED=%s' % str(test_datagen_random_seed_choice))
else:
extras.append('DATAGEN_SEED=%s' % str(_test_datagen_random_seed))
is_override = test_datagen_random_seed_choice != _test_datagen_random_seed_init
qual_list = []
# i.e. a @datagen_overrides(seed=x, permanent=True) would see:
# DATAGEN_SEED_OVERRIDE_PERMANENT=x, and if it's not permanent
# it would just be tagged as DATAGEN_SEED_OVERRIDE=x
if is_override:
qual_list += ["OVERRIDE"]
if is_permanent:
qual_list += ["PERMANENT"]
qualifier = "_".join(qual_list)
if len(qualifier) != 0:
qualifier = "_" + qualifier # prefix separator for formatting purposes
extras.append('DATAGEN_SEED%s=%s' % (qualifier, str(test_datagen_random_seed_choice)))

if injection_mode == 'random':
inject_choice = r.randrange(0, 2) == 1
4 changes: 2 additions & 2 deletions integration_tests/src/main/python/delta_lake_delete_test.py
@@ -153,7 +153,7 @@ def generate_dest_data(spark):
@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn)
@pytest.mark.parametrize("partition_columns", [None, ["a"]], ids=idfn)
@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x")
@datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/9884')
@datagen_overrides(seed=0, permanent=True, reason='https://github.com/NVIDIA/spark-rapids/issues/9884')
def test_delta_delete_rows(spark_tmp_path, use_cdf, partition_columns):
# Databricks changes the number of files being written, so we cannot compare logs unless there's only one slice
num_slices_to_test = 1 if is_databricks_runtime() else 10
@@ -172,7 +172,7 @@ def generate_dest_data(spark):
@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn)
@pytest.mark.parametrize("partition_columns", [None, ["a"]], ids=idfn)
@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x")
@datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/9884')
@datagen_overrides(seed=0, permanent=True, reason='https://github.com/NVIDIA/spark-rapids/issues/9884')
def test_delta_delete_dataframe_api(spark_tmp_path, use_cdf, partition_columns):
from delta.tables import DeltaTable
data_path = spark_tmp_path + "/DELTA_DATA"
@@ -122,7 +122,7 @@ def generate_dest_data(spark):
@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn)
@pytest.mark.parametrize("partition_columns", [None, ["a"]], ids=idfn)
@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x")
@datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/9884')
@datagen_overrides(seed=0, permanent=True, reason='https://github.com/NVIDIA/spark-rapids/issues/9884')
def test_delta_update_rows(spark_tmp_path, use_cdf, partition_columns):
# Databricks changes the number of files being written, so we cannot compare logs unless there's only one slice
num_slices_to_test = 1 if is_databricks_runtime() else 10