Enable window-group-limit optimization on [databricks] #10550

Merged: 5 commits, Mar 11, 2024
3 changes: 3 additions & 0 deletions integration_tests/src/main/python/spark_session.py
@@ -260,6 +260,9 @@ def is_databricks113_or_later():
def is_databricks122_or_later():
return is_databricks_version_or_later(12, 2)

def is_databricks133_or_later():
return is_databricks_version_or_later(13, 3)

def supports_delta_lake_deletion_vectors():
if is_databricks_runtime():
return is_databricks122_or_later()
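The new is_databricks133_or_later() gate follows the file's existing pattern of delegating to is_databricks_version_or_later(major, minor). For readers outside the repo, here is a minimal sketch of what such a version gate can look like; it is not the project's implementation, and it assumes the runtime version is readable from an environment variable (every underscore-prefixed name below is made up for illustration):

# Hypothetical sketch only -- NOT the project's helper. It assumes the Databricks
# runtime version is exposed via the DATABRICKS_RUNTIME_VERSION environment
# variable (e.g. "13.3").
import os

def _is_databricks_version_or_later(major, minor):
    version = os.environ.get('DATABRICKS_RUNTIME_VERSION')
    if not version:
        return False  # not running on a Databricks runtime
    actual_major, actual_minor = (int(p) for p in version.split('.')[:2])
    return (actual_major, actual_minor) >= (major, minor)

def _is_databricks133_or_later():
    return _is_databricks_version_or_later(13, 3)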
14 changes: 8 additions & 6 deletions integration_tests/src/main/python/window_function_test.py
@@ -21,7 +21,7 @@
from pyspark.sql.types import DateType, TimestampType, NumericType
from pyspark.sql.window import Window
import pyspark.sql.functions as f
from spark_session import is_before_spark_320, is_before_spark_350, is_databricks113_or_later, spark_version, with_cpu_session
from spark_session import is_before_spark_320, is_databricks113_or_later, is_databricks133_or_later, is_spark_350_or_later, spark_version, with_cpu_session
import warnings

_grpkey_longs_with_no_nulls = [
@@ -2042,8 +2042,9 @@ def assert_query_runs_on(exec, conf):
assert_query_runs_on(exec='GpuBatchedBoundedWindowExec', conf=conf_200)


@pytest.mark.skipif(condition=is_before_spark_350(),
reason="WindowGroupLimit not available for spark.version < 3.5")
@pytest.mark.skipif(condition=not (is_spark_350_or_later() or is_databricks133_or_later()),
reason="WindowGroupLimit not available for spark.version < 3.5 "
"and Databricks version < 13.3")
@ignore_order(local=True)
@approximate_float
@pytest.mark.parametrize('batch_size', ['1k', '1g'], ids=idfn)
@@ -2087,12 +2088,13 @@ def test_window_group_limits_for_ranking_functions(data_gen, batch_size, rank_cl
lambda spark: gen_df(spark, data_gen, length=4096),
"window_agg_table",
query,
conf = conf)
conf=conf)


@allow_non_gpu('WindowGroupLimitExec')
@pytest.mark.skipif(condition=is_before_spark_350(),
reason="WindowGroupLimit not available for spark.version < 3.5")
@pytest.mark.skipif(condition=not (is_spark_350_or_later() or is_databricks133_or_later()),
reason="WindowGroupLimit not available for spark.version < 3.5 "
" and Databricks version < 13.3")
@ignore_order(local=True)
@approximate_float
def test_window_group_limits_fallback_for_row_number():
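For context, the illustrative PySpark snippet below shows the query shape these tests exercise: a rank-based window function whose result is filtered by a small upper bound, which is the pattern the WindowGroupLimit optimization targets on Spark 3.5+ and Databricks 13.3+. The table and column names here are made up, not the test's actual data generators.

# Illustrative query shape only; not the test's generators or utilities.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.range(0, 4096) \
    .selectExpr("id % 16 AS grp", "id AS value") \
    .createOrReplaceTempView("window_agg_table")

query = """
    SELECT * FROM (
        SELECT grp, value,
               RANK() OVER (PARTITION BY grp ORDER BY value) AS rnk
        FROM window_agg_table
    ) WHERE rnk <= 5
"""
# On Spark 3.5+ (and Databricks 13.3+) the physical plan contains a
# WindowGroupLimitExec node that prunes rows per group before the full window.
spark.sql(query).explain()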
@@ -15,6 +15,7 @@
*/

/*** spark-rapids-shim-json-lines
{"spark": "341db"}
{"spark": "350"}
{"spark": "351"}
spark-rapids-shim-json-lines ***/
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive._
import org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
import org.apache.spark.sql.execution.window.WindowGroupLimitExec
import org.apache.spark.sql.rapids.GpuV1WriteUtils.GpuEmpty2Null
import org.apache.spark.sql.rapids.execution.python.GpuPythonUDAF
import org.apache.spark.sql.types.StringType
@@ -167,7 +168,15 @@ trait Spark341PlusDBShims extends Spark332PlusDBShims {
}
).disabledByDefault("Collect Limit replacement can be slower on the GPU, if huge number " +
"of rows in a batch it could help by limiting the number of rows transferred from " +
"GPU to CPU")
"GPU to CPU"),
GpuOverrides.exec[WindowGroupLimitExec](
"Apply group-limits for row groups destined for rank-based window functions like " +
"row_number(), rank(), and dense_rank()",
ExecChecks( // Similar to WindowExec.
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(),
TypeSig.all),
(limit, conf, p, r) => new GpuWindowGroupLimitExecMeta(limit, conf, p, r))
).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap

override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =
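With this override registered, the plugin can replace WindowGroupLimitExec on Databricks 13.3 (the 341db shim) the same way it does on Spark 3.5.x. A rough way to confirm the replacement from PySpark is sketched below; it assumes the RAPIDS plugin is loaded for the session and that the GPU node renders as GpuWindowGroupLimitExec in the plan string (a name inferred from GpuWindowGroupLimitExecMeta above), so treat it as a debugging aid rather than the project's test utilities.

def plan_contains(df, node_name):
    # Render the executed physical plan via py4j and search it by node name.
    return node_name in df._jdf.queryExecution().executedPlan().toString()

# Example (assumes the RAPIDS plugin is enabled and `query` is the ranking
# query shown earlier):
# assert plan_contains(spark.sql(query), "GpuWindowGroupLimitExec")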