From 29db478acbc028dc26fa3a62f65a6c3c705b097b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 29 Mar 2022 10:30:43 -0600 Subject: [PATCH] Support approx_percentile in reduction context (#4961) * Implement reduction for approximate_percentile / t-digest Signed-off-by: Andy Grove * fix regression * update generated docs --- docs/supported_ops.md | 391 +++++++----------- .../src/main/python/hash_aggregate_test.py | 40 +- .../rapids/GpuApproximatePercentile.scala | 16 +- .../nvidia/spark/rapids/GpuOverrides.scala | 2 +- .../rapids/ApproximatePercentileSuite.scala | 8 +- 5 files changed, 200 insertions(+), 257 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 16ed1d7ec45..a2577e8393d 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -14513,95 +14513,10 @@ are limited. UDT -ApproximatePercentile -`percentile_approx`, `approx_percentile` -Approximate percentile -This is not 100% compatible with the Spark version because the GPU implementation of approx_percentile is not bit-for-bit compatible with Apache Spark. To enable it, set spark.rapids.sql.incompatibleOps.enabled -reduction -input - -NS -NS -NS -NS -NS -NS -NS -NS - -NS - - - - - - - - - -percentage - - - - - - -NS - - - - - - - -NS - - - - - -accuracy - - - -NS - - - - - - - - - - - - - - - - -result - -NS -NS -NS -NS -NS -NS -NS -NS - -NS - - - -NS - - - - - +ApproximatePercentile +`percentile_approx`, `approx_percentile` +Approximate percentile +This is not 100% compatible with the Spark version because the GPU implementation of approx_percentile is not bit-for-bit compatible with Apache Spark. To enable it, set spark.rapids.sql.incompatibleOps.enabled aggregation input @@ -14687,19 +14602,19 @@ are limited. -window +reduction input -NS -NS -NS -NS -NS -NS +S +S +S +S +S +S NS NS -NS +S @@ -14716,7 +14631,7 @@ are limited. -NS +S @@ -14724,7 +14639,7 @@ are limited. -NS +S @@ -14734,7 +14649,7 @@ are limited. -NS +S @@ -14753,20 +14668,20 @@ are limited. result -NS -NS -NS -NS -NS -NS +S +S +S +S +S +S NS NS -NS +S -NS +PS
unsupported child types DATE, TIMESTAMP
@@ -14905,32 +14820,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - CollectList `collect_list` Collect a list of non-unique elements, not supported in reduction @@ -15064,6 +14953,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + CollectSet `collect_set` Collect a set of unique elements, not supported in reduction @@ -15330,32 +15245,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - First `first_value`, `first` first aggregate operator @@ -15489,6 +15378,32 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Last `last`, `last_value` last aggregate operator @@ -15755,32 +15670,6 @@ are limited. NS -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Min `min` Min aggregate operator @@ -15914,6 +15803,32 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + PivotFirst PivotFirst operator @@ -16179,32 +16094,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StddevSamp `stddev_samp`, `std`, `stddev` Aggregation computing sample standard deviation @@ -16338,6 +16227,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Sum `sum` Sum aggregate operator @@ -16604,32 +16519,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - VarianceSamp `var_samp`, `variance` Aggregation computing sample variance @@ -16763,6 +16652,32 @@ are limited. 
+Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + NormalizeNaNAndZero Normalize NaN and zero diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py index 9ce6c721d8f..4ab98da2eae 100644 --- a/integration_tests/src/main/python/hash_aggregate_test.py +++ b/integration_tests/src/main/python/hash_aggregate_test.py @@ -1271,6 +1271,14 @@ def do_it(spark): return df.groupBy('a').agg(f.min(df.b[1]["a"])) assert_gpu_and_cpu_are_equal_collect(do_it) +@incompat +@pytest.mark.parametrize('aqe_enabled', ['false', 'true'], ids=idfn) +def test_hash_groupby_approx_percentile_reduction(aqe_enabled): + conf = {'spark.sql.adaptive.enabled': aqe_enabled} + compare_percentile_approx( + lambda spark: gen_df(spark, [('v', DoubleGen())], length=100), + [0.05, 0.25, 0.5, 0.75, 0.95], conf, reduction = True) + @incompat @pytest.mark.parametrize('aqe_enabled', ['false', 'true'], ids=idfn) def test_hash_groupby_approx_percentile_byte(aqe_enabled): @@ -1405,11 +1413,11 @@ def test_hash_groupby_approx_percentile_decimal128_single(): # results due to the different algorithms being used. Instead we compute an exact percentile on the CPU and then # compute approximate percentiles on CPU and GPU and assert that the GPU numbers are accurate within some percentage # of the CPU numbers -def compare_percentile_approx(df_fun, percentiles, conf = {}): +def compare_percentile_approx(df_fun, percentiles, conf = {}, reduction = False): # create SQL statements for exact and approx percentiles - p_exact_sql = create_percentile_sql("percentile", percentiles) - p_approx_sql = create_percentile_sql("approx_percentile", percentiles) + p_exact_sql = create_percentile_sql("percentile", percentiles, reduction) + p_approx_sql = create_percentile_sql("approx_percentile", percentiles, reduction) def run_exact(spark): df = df_fun(spark) @@ -1436,8 +1444,9 @@ def run_approx(spark): gpu_approx_result = approx_gpu[i] # assert that keys match - assert cpu_exact_result['k'] == cpu_approx_result['k'] - assert cpu_exact_result['k'] == gpu_approx_result['k'] + if not reduction: + assert cpu_exact_result['k'] == cpu_approx_result['k'] + assert cpu_exact_result['k'] == gpu_approx_result['k'] # extract the percentile result column exact_percentile = cpu_exact_result['the_percentile'] @@ -1472,13 +1481,22 @@ def run_approx(spark): else: assert abs(cpu_delta / gpu_delta) - 1 < 0.001 -def create_percentile_sql(func_name, percentiles): - if isinstance(percentiles, list): - return """select k, {}(v, array({})) as the_percentile from t group by k order by k""".format( - func_name, ",".join(str(i) for i in percentiles)) +def create_percentile_sql(func_name, percentiles, reduction): + if reduction: + if isinstance(percentiles, list): + return """select {}(v, array({})) as the_percentile from t""".format( + func_name, ",".join(str(i) for i in percentiles)) + else: + return """select {}(v, {}) as the_percentile from t""".format( + func_name, percentiles) else: - return """select k, {}(v, {}) as the_percentile from t group by k order by k""".format( - func_name, percentiles) + if isinstance(percentiles, list): + return """select k, {}(v, array({})) as the_percentile from t group by k order by k""".format( + func_name, ",".join(str(i) for i in percentiles)) + else: + return """select k, {}(v, {}) as the_percentile from t group by k order 
by k""".format( + func_name, percentiles) + @ignore_order @pytest.mark.parametrize('data_gen', [_grpkey_strings_with_extra_nulls], ids=idfn) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala index f70068cfb39..fae1d8335c1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala @@ -17,7 +17,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf -import ai.rapids.cudf.GroupByAggregation +import ai.rapids.cudf.{DType, GroupByAggregation, ReductionAggregation} import com.nvidia.spark.rapids.GpuCast.doCast import com.nvidia.spark.rapids.shims.ShimExpression @@ -178,8 +178,12 @@ case class ApproxPercentileFromTDigestExpr( class CudfTDigestUpdate(accuracyExpression: GpuLiteral) extends CudfAggregate { - override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar = _ => - throw new UnsupportedOperationException("TDigest is not yet supported in reduction") + + override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar = + (col: cudf.ColumnVector) => + col.reduce(ReductionAggregation.createTDigest(CudfTDigest.accuracy(accuracyExpression)), + DType.STRUCT) + override lazy val groupByAggregate: GroupByAggregation = GroupByAggregation.createTDigest(CudfTDigest.accuracy(accuracyExpression)) override val name: String = "CudfTDigestUpdate" @@ -189,8 +193,10 @@ class CudfTDigestUpdate(accuracyExpression: GpuLiteral) class CudfTDigestMerge(accuracyExpression: GpuLiteral) extends CudfAggregate { - override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar = _ => - throw new UnsupportedOperationException("TDigest is not yet supported in reduction") + override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar = + (col: cudf.ColumnVector) => + col.reduce(ReductionAggregation.mergeTDigest(CudfTDigest.accuracy(accuracyExpression))) + override lazy val groupByAggregate: GroupByAggregation = GroupByAggregation.mergeTDigest(CudfTDigest.accuracy(accuracyExpression)) override val name: String = "CudfTDigestMerge" diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index e437867f635..dfc3b2e4e8f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3308,7 +3308,7 @@ object GpuOverrides extends Logging { }), expr[ApproximatePercentile]( "Approximate percentile", - ExprChecks.groupByOnly( + ExprChecks.reductionAndGroupByAgg( // note that output can be single number or array depending on whether percentiles param // is a single number or an array TypeSig.gpuNumeric + diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ApproximatePercentileSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ApproximatePercentileSuite.scala index c4837252816..8f3dc2305ce 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ApproximatePercentileSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ApproximatePercentileSuite.scala @@ -94,8 +94,12 @@ class ApproximatePercentileSuite extends SparkQueryCompareTestSuite { "FROM salaries GROUP BY dept") } - test("fall back to CPU for reduction") { - sqlFallbackTest("SELECT approx_percentile(salary, array(0.5)) FROM salaries") + testSparkResultsAreEqual("approx 
percentile reduction", + df => salaries(df, DataTypes.DoubleType, 100), + maxFloatDiff = 25.0, // approx percentile on GPU uses a different algorithm to Spark + incompat = true) { df => + df.createOrReplaceTempView("salaries") + df.sparkSession.sql("SELECT approx_percentile(salary, array(0.5)) FROM salaries") } def sqlFallbackTest(sql: String) {