Decimal Average Support #3898

Merged · 7 commits · Oct 26, 2021

Changes from all commits
12 changes: 6 additions & 6 deletions docs/supported_ops.md
@@ -14099,7 +14099,7 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS<br/>max DECIMAL precision of 23</em></td>
<td> </td>
<td> </td>
<td> </td>
@@ -14120,7 +14120,7 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
@@ -14142,7 +14142,7 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS<br/>max DECIMAL precision of 23</em></td>
<td> </td>
<td> </td>
<td> </td>
@@ -14163,7 +14163,7 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
@@ -14185,7 +14185,7 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS<br/>max DECIMAL precision of 23</em></td>
<td> </td>
<td> </td>
<td> </td>
@@ -14206,7 +14206,7 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
7 changes: 3 additions & 4 deletions integration_tests/src/main/python/hash_aggregate_test.py
@@ -296,15 +296,14 @@ def test_hash_reduction_sum(data_gen, conf):
@approximate_float
@ignore_order
@incompat
@pytest.mark.parametrize('data_gen', _init_list_with_nans_and_no_nans, ids=idfn)
@pytest.mark.parametrize('data_gen', _init_list_with_nans_and_no_nans + [_grpkey_short_mid_decimals, _grpkey_short_big_decimals], ids=idfn)
@pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
def test_hash_grpby_avg(data_gen, conf):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, data_gen, length=100).groupby('a').agg(f.avg('b')),
conf=conf
)


# tracks https://github.com/NVIDIA/spark-rapids/issues/154
@approximate_float
@ignore_order
@@ -813,7 +812,7 @@ def test_hash_count_with_filter(data_gen, conf):
@approximate_float
@ignore_order
@incompat
@pytest.mark.parametrize('data_gen', _init_list_no_nans, ids=idfn)
@pytest.mark.parametrize('data_gen', _init_list_no_nans + [_grpkey_short_mid_decimals, _grpkey_short_big_decimals], ids=idfn)
@pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
def test_hash_multiple_filters(data_gen, conf):
assert_gpu_and_cpu_are_equal_sql(
@@ -997,7 +996,7 @@ def test_distinct_float_count_reductions(data_gen):
'count(DISTINCT a)'))

@approximate_float
@pytest.mark.parametrize('data_gen', numeric_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', numeric_gens + [decimal_gen_12_2, decimal_gen_18_3, decimal_gen_20_2], ids=idfn)
def test_arithmetic_reductions(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
11 changes: 8 additions & 3 deletions integration_tests/src/main/python/window_function_test.py
@@ -56,17 +56,21 @@
_grpkey_longs_with_decimals = [
('a', RepeatSeqGen(LongGen(nullable=False), length=20)),
('b', DecimalGen(precision=18, scale=3, nullable=False)),
('c', IntegerGen())]
('c', DecimalGen(precision=18, scale=3))]

_grpkey_longs_with_nullable_decimals = [
('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
('b', DecimalGen(precision=18, scale=10, nullable=True)),
('c', IntegerGen())]
('c', DecimalGen(precision=18, scale=10, nullable=True))]

_grpkey_longs_with_nullable_larger_decimals = [
('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
('b', DecimalGen(precision=23, scale=10, nullable=True)),
('c', DecimalGen(precision=23, scale=10, nullable=True))]

_grpkey_decimals_with_nulls = [
('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
('b', IntegerGen()),
# the max decimal precision supported by the sum operation is 8
('c', DecimalGen(precision=8, scale=3, nullable=True))]

_grpkey_byte_with_nulls = [
Expand Down Expand Up @@ -334,6 +338,7 @@ def test_window_aggs_for_range_numeric_date(data_gen, batch_size):
_grpkey_longs_with_nullable_dates,
_grpkey_longs_with_decimals,
_grpkey_longs_with_nullable_decimals,
_grpkey_longs_with_nullable_larger_decimals,
_grpkey_decimals_with_nulls], ids=idfn)
def test_window_aggs_for_rows(data_gen, batch_size):
conf = {'spark.rapids.sql.batchSizeBytes': batch_size,
@@ -216,11 +216,15 @@ class Spark320Shims extends Spark32XShims {
GpuOverrides.expr[Average](
"Average aggregate operator",
ExprChecks.fullAgg(
TypeSig.DOUBLE, TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
// NullType is not technically allowed by Spark, but in practice in 3.2.0
// it can show up
Seq(ParamCheck("input", TypeSig.integral + TypeSig.fp + TypeSig.NULL,
TypeSig.numericAndInterval + TypeSig.NULL))),
// For Decimal Average the SUM adds a precision of 10 to avoid overflowing,
// then it divides by the count with an output scale that is 4 more than the input
// scale. With how our divide works to match Spark, this means that we will need a
// precision of 5 more, so 38 - 10 - 5 = 23.
TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
Seq(ParamCheck("input",
TypeSig.integral + TypeSig.fp + TypeSig.decimal(23),
TypeSig.numeric))),
(a, conf, p, r) => new AggExprMeta[Average](a, conf, p, r) {
override def tagAggForGpu(): Unit = {
val dataType = a.child.dataType
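As an aside, the precision budget described in the comment above can be checked in isolation. A minimal sketch, assuming nothing beyond the arithmetic in the comment (the object and constant names are made up, not project code):

// Illustrative only: reproduces the precision arithmetic from the comment above.
object DecimalAvgBudget {
  val MaxPrecision   = 38 // Spark's DecimalType.MAX_PRECISION
  val SumHeadroom    = 10 // SUM widens the precision by 10 to avoid overflow
  val DivideHeadroom = 5  // the Spark-compatible divide needs 5 more digits

  // Largest input precision Average can support on the GPU: 38 - 10 - 5 = 23
  val MaxAvgInputPrecision: Int = MaxPrecision - SumHeadroom - DivideHeadroom

  def main(args: Array[String]): Unit = {
    assert(MaxAvgInputPrecision == 23)
    // A DECIMAL(23, 2) input sums as DECIMAL(33, 2) and leaves room for the
    // divide, while a DECIMAL(24, 2) input would exceed the 38-digit budget.
    println(s"max decimal precision for Average: $MaxAvgInputPrecision")
  }
}

This is the same 23 that appears as `PS max DECIMAL precision of 23` in docs/supported_ops.md above and as `TypeSig.decimal(23)` in each shim's input check.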
@@ -304,8 +304,15 @@ abstract class SparkBaseShims extends Spark30XShims {
GpuOverrides.expr[Average](
"Average aggregate operator",
ExprChecks.fullAgg(
TypeSig.DOUBLE, TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
Seq(ParamCheck("input", TypeSig.integral + TypeSig.fp, TypeSig.numeric))),
// For Decimal Average the SUM adds a precision of 10 to avoid overflowing,
// then it divides by the count with an output scale that is 4 more than the input
// scale. With how our divide works to match Spark, this means that we will need a
// precision of 5 more, so 38 - 10 - 5 = 23.
TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
Seq(ParamCheck("input",
TypeSig.integral + TypeSig.fp + TypeSig.decimal(23),
TypeSig.numeric))),
(a, conf, p, r) => new AggExprMeta[Average](a, conf, p, r) {
override def tagAggForGpu(): Unit = {
val dataType = a.child.dataType
@@ -267,15 +267,22 @@ abstract class SparkBaseShims extends Spark30XShims {
GpuOverrides.expr[Average](
"Average aggregate operator",
ExprChecks.fullAgg(
TypeSig.DOUBLE, TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
Seq(ParamCheck("input", TypeSig.integral + TypeSig.fp, TypeSig.numeric))),
// For Decimal Average the SUM adds a precision of 10 to avoid overflowing,
// then it divides by the count with an output scale that is 4 more than the input
// scale. With how our divide works to match Spark, this means that we will need a
// precision of 5 more, so 38 - 10 - 5 = 23.
TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
Seq(ParamCheck("input",
TypeSig.integral + TypeSig.fp + TypeSig.decimal(23),
TypeSig.numeric))),
(a, conf, p, r) => new AggExprMeta[Average](a, conf, p, r) {
override def tagAggForGpu(): Unit = {
val dataType = a.child.dataType
GpuOverrides.checkAndTagFloatAgg(dataType, conf, this)
}

override def convertToGpu(childExprs: Seq[Expression]): GpuExpression =
GpuAverage(childExprs.head)

// Average is not supported in ANSI mode right now, no matter the type
@@ -201,8 +201,15 @@ abstract class SparkBaseShims extends Spark30XShims {
GpuOverrides.expr[Average](
"Average aggregate operator",
ExprChecks.fullAgg(
TypeSig.DOUBLE, TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
Seq(ParamCheck("input", TypeSig.integral + TypeSig.fp, TypeSig.numeric))),
// For Decimal Average the SUM adds a precision of 10 to avoid overflowing,
// then it divides by the count with an output scale that is 4 more than the input
// scale. With how our divide works to match Spark, this means that we will need a
// precision of 5 more, so 38 - 10 - 5 = 23.
TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
Seq(ParamCheck("input",
TypeSig.integral + TypeSig.fp + TypeSig.decimal(23),
TypeSig.numeric))),
(a, conf, p, r) => new AggExprMeta[Average](a, conf, p, r) {
override def tagAggForGpu(): Unit = {
val dataType = a.child.dataType
@@ -199,8 +199,15 @@ abstract class SparkBaseShims extends Spark31XShims {
GpuOverrides.expr[Average](
"Average aggregate operator",
ExprChecks.fullAgg(
TypeSig.DOUBLE, TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
Seq(ParamCheck("input", TypeSig.integral + TypeSig.fp, TypeSig.numeric))),
// For Decimal Average the SUM adds a precision of 10 to avoid overflowing,
// then it divides by the count with an output scale that is 4 more than the input
// scale. With how our divide works to match Spark, this means that we will need a
// precision of 5 more, so 38 - 10 - 5 = 23.
TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
Seq(ParamCheck("input",
TypeSig.integral + TypeSig.fp + TypeSig.decimal(23),
TypeSig.numeric))),
(a, conf, p, r) => new AggExprMeta[Average](a, conf, p, r) {
override def tagAggForGpu(): Unit = {
val dataType = a.child.dataType
@@ -2055,7 +2055,8 @@ object GpuOverrides extends Logging {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
a.dataType match {
case _: DecimalType =>
throw new IllegalStateException("Decimal Divide should be converted in CheckOverflow")
throw new IllegalStateException("Internal Error: Decimal Divide operations " +
"should be converted to the GPU in the CheckOverflow rule")
case _ =>
GpuDivide(lhs, rhs)
}
@@ -319,9 +319,19 @@

exprs.foreach { expr =>
if (hasGpuWindowFunction(expr)) {
// First pass looks for GpuWindowFunctions and GpuWindowSpecDefinitions to build up
// First pass replaces any operations that should be totally replaced.
val replacePass = expr.transformDown {
case GpuWindowExpression(
GpuAggregateExpression(rep: GpuReplaceWindowFunction, _, _, _, _), spec) =>
// We don't actually care about the GpuAggregateExpression because it is ignored
// by our GPU window operations anyway.
rep.windowReplacement(spec)
case GpuWindowExpression(rep: GpuReplaceWindowFunction, spec) =>
rep.windowReplacement(spec)
}
// Second pass looks for GpuWindowFunctions and GpuWindowSpecDefinitions to build up
// the preProject phase
val firstPass = expr.transformDown {
val secondPass = replacePass.transformDown {
case wf: GpuWindowFunction =>
// All window functions, including those that are also aggregation functions, are
// wrapped in a GpuWindowExpression, so dedup and save their children into the pre
@@ -340,14 +350,15 @@
}.toArray.toSeq
wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec)
}
val secondPass = firstPass.transformDown {
// Final pass is to extract, dedup, and save the results.
val finalPass = secondPass.transformDown {
case we: GpuWindowExpression =>
// A window Expression holds a window function or an aggregate function, so put it into
// the windowOps phase, and create a new alias for it for the post phase
extractAndSave(we, windowOps, windowDedupe)
}.asInstanceOf[NamedExpression]

postProject += secondPass
postProject += finalPass
} else {
// There is no window function so pass the result through all of the phases (with deduping)
postProject += extractAndSave(
@@ -618,6 +618,22 @@ case class GpuSpecialFrameBoundary(boundary : SpecialFrameBoundary)
// Spark. This may expand in the future if other types of window functions show up.
trait GpuWindowFunction extends GpuUnevaluable with ShimExpression

/**
* This is a special window function that simply replaces itself with one or more
* window functions and other expressions that can be executed. This allows you to write
* `GpuAverage` in terms of `GpuSum` and `GpuCount`, which both work with all of the window
* optimizations, allowing `GpuAverage` to do the same.
*/
trait GpuReplaceWindowFunction extends GpuWindowFunction {
/**
* Return a single new expression that can replace the existing aggregation in window
* calculations. Note that this requires that there are no nested window operations.
* For example, you cannot currently do a SUM of AVERAGEs with this. That support may be
* added in the future.
*/
def windowReplacement(spec: GpuWindowSpecDefinition): Expression
}

/**
* GPU Counterpart of `AggregateWindowFunction`.
* On the CPU this would extend `DeclarativeAggregate` and use the provided methods
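To make the contract concrete, here is a self-contained toy model of this trait and of the replace pass added to GpuWindowExec above. Illustrative only: the names shadow the real ones, but none of this is project code.

// Illustrative only: a toy model of GpuReplaceWindowFunction and the replace pass.
object ReplaceWindowToy {
  sealed trait Expr
  case class Col(name: String) extends Expr
  case class Sum(child: Expr) extends Expr
  case class Count(child: Expr) extends Expr
  case class Div(lhs: Expr, rhs: Expr) extends Expr
  case class WindowExpr(fn: Expr, spec: String) extends Expr

  // Mirrors GpuReplaceWindowFunction: average expands into window expressions
  // that can actually be executed, here sum/count over the same window spec.
  case class Avg(child: Expr) extends Expr {
    def windowReplacement(spec: String): Expr =
      Div(WindowExpr(Sum(child), spec), WindowExpr(Count(child), spec))
  }

  // Mirrors the replace pass: rewrite replaceable functions before the later
  // extraction passes ever see them.
  def replacePass(e: Expr): Expr = e match {
    case WindowExpr(a: Avg, spec) => a.windowReplacement(spec)
    case other => other
  }

  def main(args: Array[String]): Unit = {
    // AVG(b) OVER w is rewritten to SUM(b) OVER w / COUNT(b) OVER w.
    assert(replacePass(WindowExpr(Avg(Col("b")), "w")) ==
      Div(WindowExpr(Sum(Col("b")), "w"), WindowExpr(Count(Col("b")), "w")))
  }
}

The rewrite produces only plain window functions, which is why, per the doc comment above, nested replacements such as a SUM of AVERAGEs are not supported yet.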
@@ -535,6 +535,9 @@ object TypeSig {
def psNote(dataType: TypeEnum.Value, note: String): TypeSig =
TypeSig.none.withPsNote(dataType, note)

def decimal(maxPrecision: Int): TypeSig =
new TypeSig(TypeEnum.ValueSet(TypeEnum.DECIMAL), maxPrecision)

/**
* All types nested and not nested
*/
@@ -558,16 +561,15 @@
val DATE: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.DATE))
val TIMESTAMP: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.TIMESTAMP))
val STRING: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.STRING))
val DECIMAL_64: TypeSig =
new TypeSig(TypeEnum.ValueSet(TypeEnum.DECIMAL), DType.DECIMAL64_MAX_PRECISION)
val DECIMAL_64: TypeSig = decimal(DType.DECIMAL64_MAX_PRECISION)

/**
* Full support for 128 bit DECIMAL. In the future we expect to have other types with
* slightly less than full DECIMAL support. These are things like math operations where
* we cannot replicate the overflow behavior of Spark. These will be added when needed.
*/
val DECIMAL_128_FULL: TypeSig =
new TypeSig(TypeEnum.ValueSet(TypeEnum.DECIMAL), DecimalType.MAX_PRECISION)
val DECIMAL_128_FULL: TypeSig = decimal(DType.DECIMAL128_MAX_PRECISION)

val NULL: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.NULL))
val BINARY: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.BINARY))
val CALENDAR: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.CALENDAR))
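The effect of the new `decimal(maxPrecision)` factory can be modeled in miniature. A hedged sketch (the real TypeSig support checks are more involved; `ToyTypeSig` and its methods are made up):

// Illustrative only: a precision-bounded decimal signature in miniature.
object TypeSigToy {
  final case class ToyTypeSig(types: Set[String], maxDecimalPrecision: Int) {
    def isSupported(tpe: String, precision: Int): Boolean =
      types(tpe) && precision <= maxDecimalPrecision
  }

  def decimal(maxPrecision: Int): ToyTypeSig = ToyTypeSig(Set("DECIMAL"), maxPrecision)

  def main(args: Array[String]): Unit = {
    val decimal23 = decimal(23) // like TypeSig.decimal(23) in the Average checks
    assert(decimal23.isSupported("DECIMAL", 18))  // DECIMAL(18, _) stays on the GPU
    assert(!decimal23.isSupported("DECIMAL", 24)) // precision 24 falls back to the CPU
  }
}

With the factory in place, DECIMAL_64 and DECIMAL_128_FULL above reduce to decimal(DType.DECIMAL64_MAX_PRECISION) and decimal(DType.DECIMAL128_MAX_PRECISION), as the diff shows.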
@@ -45,7 +45,7 @@ class GpuProjectExecMeta(
with Logging {
override def convertToGpu(): GpuExec = {
// Force list to avoid recursive Java serialization of lazy list Seq implementation
val gpuExprs = childExprs.map(_.convertToGpu()).toList
val gpuExprs = childExprs.map(_.convertToGpu().asInstanceOf[NamedExpression]).toList
val gpuChild = childPlans.head.convertIfNeeded()
if (conf.isProjectAstEnabled) {
if (childExprs.forall(_.canThisBeAst)) {
@@ -120,16 +120,14 @@ case class GpuProjectExec(
// using an Array, we opt in for List because it implements Seq while having non-recursive
// serde: https://github.com/scala/scala/blob/2.12.x/src/library/scala/collection/
// immutable/List.scala#L516
projectList: List[Expression],
projectList: List[NamedExpression],
child: SparkPlan
) extends ShimUnaryExecNode with GpuExec {

override lazy val additionalMetrics: Map[String, GpuMetric] = Map(
OP_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_OP_TIME))

override def output: Seq[Attribute] = {
projectList.collect { case ne: NamedExpression => ne.toAttribute }
}
override def output: Seq[Attribute] = projectList.map(_.toAttribute)

override def outputOrdering: Seq[SortOrder] = child.outputOrdering
