diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py
index a0106ab9ec6e..f83fb0e725f2 100644
--- a/integration_tests/src/main/python/window_function_test.py
+++ b/integration_tests/src/main/python/window_function_test.py
@@ -21,7 +21,7 @@
 from pyspark.sql.types import NumericType
 from pyspark.sql.window import Window
 import pyspark.sql.functions as f
-from spark_session import is_before_spark_320, is_before_spark_340, is_databricks113_or_later
+from spark_session import is_before_spark_320, is_databricks113_or_later
 import warnings
 
 _grpkey_longs_with_no_nulls = [
@@ -74,6 +74,16 @@
     ('b', DecimalGen(precision=38, scale=2, nullable=True)),
     ('c', DecimalGen(precision=38, scale=2, nullable=True))]
 
+_grpkey_longs_with_nullable_floats = [
+    ('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
+    ('b', FloatGen(nullable=True)),
+    ('c', IntegerGen(nullable=True))]
+
+_grpkey_longs_with_nullable_doubles = [
+    ('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
+    ('b', DoubleGen(nullable=True)),
+    ('c', IntegerGen(nullable=True))]
+
 _grpkey_decimals_with_nulls = [
     ('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
     ('b', IntegerGen()),
@@ -879,15 +889,17 @@ def test_window_aggs_for_ranges_timestamps(data_gen):
     pytest.param(_grpkey_longs_with_nullable_largest_decimals,
                  marks=pytest.mark.xfail(
                      condition=is_databricks113_or_later(),
-                     reason='https://github.com/NVIDIA/spark-rapids/issues/7429'))
+                     reason='https://github.com/NVIDIA/spark-rapids/issues/7429')),
+    _grpkey_longs_with_nullable_floats,
+    _grpkey_longs_with_nullable_doubles
 ], ids=idfn)
-def test_window_aggregations_for_decimal_ranges(data_gen):
+def test_window_aggregations_for_decimal_and_float_ranges(data_gen):
     """
-    Tests for range window aggregations, with DECIMAL order by columns.
+    Tests for range window aggregations, with DECIMAL/FLOATING POINT order by columns.
     The table schema used:
       a: Group By column
-      b: Order By column (decimal)
-      c: Aggregation column (incidentally, also decimal)
+      b: Order By column (decimals, floats, doubles)
+      c: Aggregation column (decimals or ints)
 
     Since this test is for the order-by column type, and not for each specific windowing
     aggregation, we use COUNT(1) throughout the test, for different window widths and ordering.
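For context (this sketch is not part of the patch), the new float/double parametrizations boil down to queries of roughly the following shape: a `COUNT(1)` over a range frame whose ORDER BY column is floating point. The table and column names below are illustrative only; the schema mirrors the generated test data (long grouping key, float order-by, int aggregation column):

```scala
import org.apache.spark.sql.SparkSession

object FloatRangeWindowExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("float-range-window-sketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Illustrative stand-in for the generated test data.
    Seq((1L, 1.0f, 10), (1L, 1.5f, 20), (1L, 3.0f, 30), (2L, 0.5f, 40))
      .toDF("a", "b", "c")
      .createOrReplaceTempView("window_test_table")

    // COUNT(1) over a range frame with floating-point bounds, matching the
    // query shape the new float/double test cases exercise.
    spark.sql(
      """SELECT a, b, c,
        |       COUNT(1) OVER (PARTITION BY a ORDER BY b
        |                      RANGE BETWEEN CAST(0.5 AS FLOAT) PRECEDING
        |                                AND CAST(0.5 AS FLOAT) FOLLOWING) AS cnt
        |FROM window_test_table""".stripMargin).show()

    spark.stop()
  }
}
```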
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 2135bab7245f..744c5844a751 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -943,10 +943,12 @@ object GpuOverrides extends Logging {
         TypeSig.numericAndInterval,
         Seq(
           ParamCheck("lower",
-            TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_128,
+            TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_128 +
+              TypeSig.FLOAT + TypeSig.DOUBLE,
             TypeSig.numericAndInterval),
           ParamCheck("upper",
-            TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_128,
+            TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_128 +
+              TypeSig.FLOAT + TypeSig.DOUBLE,
             TypeSig.numericAndInterval))),
       (windowFrame, conf, p, r) => new GpuSpecifiedWindowFrameMeta(windowFrame, conf, p, r) ),
     expr[WindowSpecDefinition](
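Not part of the patch, but a plausible way to observe the effect of this signature change: with the plugin active, float/double bounds on the frame's `lower`/`upper` parameters should no longer trip the `ParamCheck` fallback. Continuing the earlier sketch (the `spark` session and `window_test_table` view are assumed from there, and the exact operator name in the plan may vary by plugin version):

```scala
// Assumes a session with the RAPIDS plugin on the classpath and enabled.
spark.sql(
  """SELECT a, b,
    |       COUNT(1) OVER (PARTITION BY a ORDER BY b
    |                      RANGE BETWEEN CAST(0.5 AS FLOAT) PRECEDING
    |                                AND CAST(0.5 AS FLOAT) FOLLOWING) AS cnt
    |FROM window_test_table""".stripMargin)
  .explain()  // with the fix, expect a GPU window exec rather than a CPU Window fallback
```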
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala
index e67e0a25e62d..2d50c5a66419 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala
@@ -636,7 +636,47 @@ case class BoundGpuWindowFunction(
   val dataType: DataType = windowFunc.dataType
 }
 
-case class ParsedBoundary(isUnbounded: Boolean, value: Either[BigInt, Long])
+/**
+ * Abstraction for possible range-boundary specifications.
+ *
+ * This provides type disjunction for Long, BigInt and Double,
+ * the three types that might represent a range boundary.
+ */
+abstract class RangeBoundaryValue {
+  def long: Long = RangeBoundaryValue.long(this)
+  def bigInt: BigInt = RangeBoundaryValue.bigInt(this)
+  def double: Double = RangeBoundaryValue.double(this)
+}
+
+case class LongRangeBoundaryValue(value: Long) extends RangeBoundaryValue
+case class BigIntRangeBoundaryValue(value: BigInt) extends RangeBoundaryValue
+case class DoubleRangeBoundaryValue(value: Double) extends RangeBoundaryValue
+
+object RangeBoundaryValue {
+
+  def long(boundary: RangeBoundaryValue): Long = boundary match {
+    case LongRangeBoundaryValue(l) => l
+    case other => throw new NoSuchElementException(s"Cannot get `long` from $other")
+  }
+
+  def bigInt(boundary: RangeBoundaryValue): BigInt = boundary match {
+    case BigIntRangeBoundaryValue(b) => b
+    case other => throw new NoSuchElementException(s"Cannot get `bigInt` from $other")
+  }
+
+  def double(boundary: RangeBoundaryValue): Double = boundary match {
+    case DoubleRangeBoundaryValue(d) => d
+    case other => throw new NoSuchElementException(s"Cannot get `double` from $other")
+  }
+
+  def long(value: Long): LongRangeBoundaryValue = LongRangeBoundaryValue(value)
+
+  def bigInt(value: BigInt): BigIntRangeBoundaryValue = BigIntRangeBoundaryValue(value)
+
+  def double(value: Double): DoubleRangeBoundaryValue = DoubleRangeBoundaryValue(value)
+}
+
+case class ParsedBoundary(isUnbounded: Boolean, value: RangeBoundaryValue)
 
 object GroupedAggregations {
   /**
@@ -754,22 +794,23 @@
       if (bound.isUnbounded) {
         None
       } else {
-        val valueLong = bound.value.right // Used for all cases except DECIMAL128.
         val s = orderByType match {
-          case DType.INT8 => Scalar.fromByte(valueLong.get.toByte)
-          case DType.INT16 => Scalar.fromShort(valueLong.get.toShort)
-          case DType.INT32 => Scalar.fromInt(valueLong.get.toInt)
-          case DType.INT64 => Scalar.fromLong(valueLong.get)
+          case DType.INT8 => Scalar.fromByte(bound.value.long.toByte)
+          case DType.INT16 => Scalar.fromShort(bound.value.long.toShort)
+          case DType.INT32 => Scalar.fromInt(bound.value.long.toInt)
+          case DType.INT64 => Scalar.fromLong(bound.value.long)
+          case DType.FLOAT32 => Scalar.fromFloat(bound.value.double.toFloat)
+          case DType.FLOAT64 => Scalar.fromDouble(bound.value.double)
           // Interval is not working for DateType
-          case DType.TIMESTAMP_DAYS => Scalar.durationFromLong(DType.DURATION_DAYS, valueLong.get)
+          case DType.TIMESTAMP_DAYS => Scalar.durationFromLong(DType.DURATION_DAYS, bound.value.long)
           case DType.TIMESTAMP_MICROSECONDS =>
-            Scalar.durationFromLong(DType.DURATION_MICROSECONDS, valueLong.get)
+            Scalar.durationFromLong(DType.DURATION_MICROSECONDS, bound.value.long)
           case x if x.getTypeId == DType.DTypeEnum.DECIMAL32 =>
-            Scalar.fromDecimal(x.getScale, valueLong.get.toInt)
+            Scalar.fromDecimal(x.getScale, bound.value.long.toInt)
           case x if x.getTypeId == DType.DTypeEnum.DECIMAL64 =>
-            Scalar.fromDecimal(x.getScale, valueLong.get)
+            Scalar.fromDecimal(x.getScale, bound.value.long)
           case x if x.getTypeId == DType.DTypeEnum.DECIMAL128 =>
-            Scalar.fromDecimal(x.getScale, bound.value.left.get.underlying())
+            Scalar.fromDecimal(x.getScale, bound.value.bigInt.underlying())
           case x if x.getTypeId == DType.DTypeEnum.STRING =>
             // Not UNBOUNDED. The only other supported boundary for String is CURRENT ROW, i.e. 0.
             Scalar.fromString("")
@@ -782,36 +823,52 @@
   private def getRangeBoundaryValue(boundary: Expression, orderByType: DType): ParsedBoundary =
     boundary match {
       case special: GpuSpecialFrameBoundary =>
-        val isUnBounded = special.isUnbounded
-        val isDecimal128 = orderByType.getTypeId == DType.DTypeEnum.DECIMAL128
-        ParsedBoundary(isUnBounded, if (isDecimal128) Left(special.value) else Right(special.value))
+        ParsedBoundary(
+          isUnbounded = special.isUnbounded,
+          value = orderByType.getTypeId match {
+            case DType.DTypeEnum.DECIMAL128 => RangeBoundaryValue.bigInt(special.value)
+            case DType.DTypeEnum.FLOAT32 | DType.DTypeEnum.FLOAT64 =>
+              RangeBoundaryValue.double(special.value)
+            case _ => RangeBoundaryValue.long(special.value)
+          }
+        )
       case GpuLiteral(ci: CalendarInterval, CalendarIntervalType) =>
         // Get the total microseconds for TIMESTAMP_MICROSECONDS
         var x = TimeUnit.DAYS.toMicros(ci.days) + ci.microseconds
         if (x == Long.MinValue) x = Long.MaxValue
-        ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
+        ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
       case GpuLiteral(value, ByteType) =>
         var x = value.asInstanceOf[Byte]
         if (x == Byte.MinValue) x = Byte.MaxValue
-        ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
+        ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
       case GpuLiteral(value, ShortType) =>
         var x = value.asInstanceOf[Short]
        if (x == Short.MinValue) x = Short.MaxValue
-        ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
+        ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
       case GpuLiteral(value, IntegerType) =>
         var x = value.asInstanceOf[Int]
         if (x == Int.MinValue) x = Int.MaxValue
-        ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
+        ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
       case GpuLiteral(value, LongType) =>
         var x = value.asInstanceOf[Long]
         if (x == Long.MinValue) x = Long.MaxValue
-        ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
+        ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
+      case GpuLiteral(value, FloatType) =>
+        var x = value.asInstanceOf[Float]
+        if (x == Float.MinValue) x = Float.MaxValue
+        ParsedBoundary(isUnbounded = false, RangeBoundaryValue.double(Math.abs(x)))
+      case GpuLiteral(value, DoubleType) =>
+        var x = value.asInstanceOf[Double]
+        if (x == Double.MinValue) x = Double.MaxValue
+        ParsedBoundary(isUnbounded = false, RangeBoundaryValue.double(Math.abs(x)))
       case GpuLiteral(value: Decimal, DecimalType()) =>
         orderByType.getTypeId match {
           case DType.DTypeEnum.DECIMAL32 | DType.DTypeEnum.DECIMAL64 =>
-            ParsedBoundary(isUnbounded = false, Right(Math.abs(value.toUnscaledLong)))
+            ParsedBoundary(isUnbounded = false,
+              RangeBoundaryValue.long(Math.abs(value.toUnscaledLong)))
           case DType.DTypeEnum.DECIMAL128 =>
-            ParsedBoundary(isUnbounded = false, Left(value.toJavaBigDecimal.unscaledValue().abs))
+            ParsedBoundary(isUnbounded = false,
+              RangeBoundaryValue.bigInt(value.toJavaBigDecimal.unscaledValue().abs))
           case anythingElse =>
             throw new UnsupportedOperationException(s"Unexpected Decimal type: $anythingElse")
         }
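To make the new disjunction concrete, here is a minimal, self-contained sketch of how the `RangeBoundaryValue` accessors behave; the class bodies are trimmed copies of the patch's definitions, for illustration only. A mismatched accessor fails fast with `NoSuchElementException` rather than silently widening:

```scala
object RangeBoundarySketch {
  // Trimmed copy of the patch's disjunction type.
  sealed abstract class RangeBoundaryValue
  case class LongRangeBoundaryValue(value: Long) extends RangeBoundaryValue
  case class BigIntRangeBoundaryValue(value: BigInt) extends RangeBoundaryValue
  case class DoubleRangeBoundaryValue(value: Double) extends RangeBoundaryValue

  def double(b: RangeBoundaryValue): Double = b match {
    case DoubleRangeBoundaryValue(d) => d
    case other => throw new NoSuchElementException(s"Cannot get `double` from $other")
  }

  def main(args: Array[String]): Unit = {
    // A float/double order-by column carries its boundary as a Double.
    val floatBound: RangeBoundaryValue = DoubleRangeBoundaryValue(0.5)
    println(double(floatBound)) // 0.5

    // Asking for the wrong representation fails loudly.
    val longBound: RangeBoundaryValue = LongRangeBoundaryValue(3L)
    try double(longBound)
    catch { case e: NoSuchElementException => println(e.getMessage) }
  }
}
```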
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala
index 56615a74c48b..7ffbdc807445 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala
@@ -101,12 +101,12 @@ abstract class GpuWindowExpressionMetaBase(
     val orderSpec = wrapped.windowSpec.orderSpec
     if (orderSpec.length > 1) {
       // We only support a single order by column
-      willNotWorkOnGpu("only a single date/time or integral (Boolean exclusive)" +
+      willNotWorkOnGpu("only a single date/time or numeric (Boolean exclusive) " +
         "based column in window range functions is supported")
     }
     val orderByTypeSupported = orderSpec.forall { so =>
       so.dataType match {
-        case ByteType | ShortType | IntegerType | LongType |
+        case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
              DateType | TimestampType | StringType | DecimalType() => true
         case _ => false
       }
@@ -134,6 +134,14 @@ abstract class GpuWindowExpressionMetaBase(
             s"Range window frame is not 100% compatible when the order by type is " +
             s"long and the range value calculated has overflow. " +
             s"To enable it please set ${RapidsConf.ENABLE_RANGE_WINDOW_LONG} to true.")
+        case FloatType => if (!conf.isRangeWindowFloatEnabled) willNotWorkOnGpu(
+            s"Range window frame is not 100% compatible when the order by type is " +
+            s"float and the range value calculated has overflow. " +
+            s"To enable it please set ${RapidsConf.ENABLE_RANGE_WINDOW_FLOAT} to true.")
+        case DoubleType => if (!conf.isRangeWindowDoubleEnabled) willNotWorkOnGpu(
+            s"Range window frame is not 100% compatible when the order by type is " +
+            s"double and the range value calculated has overflow. " +
+            s"To enable it please set ${RapidsConf.ENABLE_RANGE_WINDOW_DOUBLE} to true.")
         case DecimalType() => if (!conf.isRangeWindowDecimalEnabled) willNotWorkOnGpu(
             s"To enable DECIMAL order by columns with Range window frames, " +
             s"please set ${RapidsConf.ENABLE_RANGE_WINDOW_DECIMAL} to true.")
@@ -144,7 +152,7 @@ abstract class GpuWindowExpressionMetaBase(
       // check whether the boundaries are supported or not.
       Seq(spec.lower, spec.upper).foreach {
         case l @ Literal(_, ByteType | ShortType | IntegerType |
-                            LongType | DecimalType()) =>
+                            LongType | FloatType | DoubleType | DecimalType()) =>
           checkRangeBoundaryConfig(l.dataType)
         case Literal(ci: CalendarInterval, CalendarIntervalType) =>
           // interval is only working for TimeStampType
@@ -356,7 +364,7 @@ abstract class GpuSpecifiedWindowFrameMetaBase(
    * Tag RangeFrame for other types and get the value
    */
  def getAndTagOtherTypesForRangeFrame(bounds : Expression, isLower : Boolean): Long = {
-    willNotWorkOnGpu(s"Bounds for Range-based window frames must be specified in Integral" +
+    willNotWorkOnGpu(s"Bounds for Range-based window frames must be specified in numeric" +
       s" type (Boolean exclusive) or CalendarInterval. Found ${bounds.dataType}")
     if (isLower) -1 else 1 // not check again
   }
@@ -377,36 +385,58 @@ abstract class GpuSpecifiedWindowFrameMetaBase(
       return None
     }
 
-    val value: BigInt = bounds match {
-      case Literal(value, ByteType) => value.asInstanceOf[Byte].toLong
-      case Literal(value, ShortType) => value.asInstanceOf[Short].toLong
-      case Literal(value, IntegerType) => value.asInstanceOf[Int].toLong
-      case Literal(value, LongType) => value.asInstanceOf[Long]
-      case Literal(value: Decimal, DecimalType()) => value.toJavaBigDecimal.unscaledValue()
+    /**
+     * Check bounds value relative to current row:
+     *   1. lower-bound should not be ahead of the current row.
+     *   2. upper-bound should not be behind the current row.
+     */
+    def checkBounds[T](boundsValue: T)
+                      (implicit ev: Numeric[T]): Option[String] = {
+      if (isLower && ev.compare(boundsValue, ev.zero) > 0) {
+        Some(s"Lower-bounds ahead of current row is not supported. Found: $boundsValue")
+      }
+      else if (!isLower && ev.compare(boundsValue, ev.zero) < 0) {
+        Some(s"Upper-bounds behind current row is not supported. Found: $boundsValue")
+      }
+      else {
+        None
+      }
+    }
+
+    bounds match {
+      case Literal(value, ByteType) =>
+        checkBounds(value.asInstanceOf[Byte].toLong)
+      case Literal(value, ShortType) =>
+        checkBounds(value.asInstanceOf[Short].toLong)
+      case Literal(value, IntegerType) =>
+        checkBounds(value.asInstanceOf[Int].toLong)
+      case Literal(value, LongType) =>
+        checkBounds(value.asInstanceOf[Long])
+      case Literal(value, FloatType) =>
+        checkBounds(value.asInstanceOf[Float])
+      case Literal(value, DoubleType) =>
+        checkBounds(value.asInstanceOf[Double])
+      case Literal(value: Decimal, DecimalType()) =>
+        checkBounds(BigInt(value.toJavaBigDecimal.unscaledValue()))
       case Literal(ci: CalendarInterval, CalendarIntervalType) =>
         if (ci.months != 0) {
           willNotWorkOnGpu("interval months isn't supported")
         }
         // return the total microseconds
         try {
-          Math.addExact(
-            Math.multiplyExact(ci.days.toLong, TimeUnit.DAYS.toMicros(1)),
-            ci.microseconds)
+          checkBounds(
+            Math.addExact(
+              Math.multiplyExact(ci.days.toLong, TimeUnit.DAYS.toMicros(1)),
+              ci.microseconds))
         } catch {
           case _: ArithmeticException =>
             willNotWorkOnGpu("windows over timestamps are converted to microseconds " +
               s"and $ci is too large to fit")
-            if (isLower) -1 else 1 // not check again
+            None
         }
-      case _ => getAndTagOtherTypesForRangeFrame(bounds, isLower)
-    }
-
-    if (isLower && value > 0) {
-      Some(s"Lower-bounds ahead of current row is not supported. Found: $value")
-    } else if (!isLower && value < 0) {
-      Some(s"Upper-bounds behind current row is not supported. Found: $value")
-    } else {
-      None
+      case _ =>
+        getAndTagOtherTypesForRangeFrame(bounds, isLower)
+        None
     }
   }
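As a standalone illustration of the refactored check: the logic below is lifted from the hunk above, with `isLower` made an explicit parameter so the sketch compiles on its own. Using `Numeric[T]` lets one comparison serve Long, Float, Double and BigInt bounds alike:

```scala
object CheckBoundsSketch {
  // Mirrors the patch's checkBounds: a lower bound must not sit ahead of the
  // current row (positive offset), an upper bound must not sit behind it.
  def checkBounds[T](boundsValue: T, isLower: Boolean)
                    (implicit ev: Numeric[T]): Option[String] = {
    if (isLower && ev.compare(boundsValue, ev.zero) > 0) {
      Some(s"Lower-bounds ahead of current row is not supported. Found: $boundsValue")
    } else if (!isLower && ev.compare(boundsValue, ev.zero) < 0) {
      Some(s"Upper-bounds behind current row is not supported. Found: $boundsValue")
    } else {
      None
    }
  }

  def main(args: Array[String]): Unit = {
    println(checkBounds(-2.5, isLower = true))        // None: "2.5 PRECEDING" is valid
    println(checkBounds(2.5, isLower = true))         // Some(...): lower bound ahead of row
    println(checkBounds(0L, isLower = false))         // None: CURRENT ROW as upper bound
    println(checkBounds(BigInt(-1), isLower = false)) // Some(...): upper bound behind row
  }
}
```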
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index 548dc5884b3b..e04fc7edcd67 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -1266,6 +1266,20 @@ object RapidsConf {
       .booleanConf
       .createWithDefault(true)
 
+  val ENABLE_RANGE_WINDOW_FLOAT: ConfEntryWithDefault[Boolean] =
+    conf("spark.rapids.sql.window.range.float.enabled")
+      .doc("When set to false, this disables the range window acceleration for the " +
+        "FLOAT type order-by column")
+      .booleanConf
+      .createWithDefault(true)
+
+  val ENABLE_RANGE_WINDOW_DOUBLE: ConfEntryWithDefault[Boolean] =
+    conf("spark.rapids.sql.window.range.double.enabled")
+      .doc("When set to false, this disables the range window acceleration for the " +
+        "DOUBLE type order-by column")
+      .booleanConf
+      .createWithDefault(true)
+
   val ENABLE_RANGE_WINDOW_DECIMAL: ConfEntryWithDefault[Boolean] =
     conf("spark.rapids.sql.window.range.decimal.enabled")
       .doc("When set to false, this disables the range window acceleration for the " +
@@ -2569,6 +2583,10 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
 
   lazy val isRangeWindowLongEnabled: Boolean = get(ENABLE_RANGE_WINDOW_LONG)
 
+  lazy val isRangeWindowFloatEnabled: Boolean = get(ENABLE_RANGE_WINDOW_FLOAT)
+
+  lazy val isRangeWindowDoubleEnabled: Boolean = get(ENABLE_RANGE_WINDOW_DOUBLE)
+
   lazy val isRangeWindowDecimalEnabled: Boolean = get(ENABLE_RANGE_WINDOW_DECIMAL)
 
   lazy val allowSinglePassPartialSortAgg: Boolean = get(ENABLE_SINGLE_PASS_PARTIAL_SORT_AGG)
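Both new configs default to true, mirroring the existing long/decimal flags. A usage sketch follows; only the config keys are taken from the patch, and `spark` is assumed to be an active SparkSession with the plugin enabled:

```scala
// Disable GPU range windows over float/double order-by columns at runtime;
// affected window execs then fall back to the CPU with the messages above.
spark.conf.set("spark.rapids.sql.window.range.float.enabled", "false")
spark.conf.set("spark.rapids.sql.window.range.double.enabled", "false")
```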
diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
index f1187b39a486..5f1b8576249a 100644
--- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
+++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
@@ -227,11 +227,11 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
         Seq(
           ParamCheck("lower",
             TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DAYTIME
-              + TypeSig.DECIMAL_128,
+              + TypeSig.DECIMAL_128 + TypeSig.FLOAT + TypeSig.DOUBLE,
             TypeSig.numericAndInterval),
           ParamCheck("upper",
             TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DAYTIME
-              + TypeSig.DECIMAL_128,
+              + TypeSig.DECIMAL_128 + TypeSig.FLOAT + TypeSig.DOUBLE,
             TypeSig.numericAndInterval))),
       (windowFrame, conf, p, r) => new GpuSpecifiedWindowFrameMeta(windowFrame, conf, p, r)),
     GpuOverrides.expr[WindowExpression](
diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/gpuWindows.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/gpuWindows.scala
index f68e825c3ea9..76c5705f6b09 100644
--- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/gpuWindows.scala
+++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/gpuWindows.scala
@@ -33,7 +33,7 @@ spark-rapids-shim-json-lines ***/
 package com.nvidia.spark.rapids.shims
 
-import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuLiteral, GpuSpecifiedWindowFrameMetaBase, GpuWindowExpressionMetaBase, ParsedBoundary, RapidsConf, RapidsMeta}
+import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuLiteral, GpuSpecifiedWindowFrameMetaBase, GpuWindowExpressionMetaBase, ParsedBoundary, RangeBoundaryValue, RapidsConf, RapidsMeta}
 
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SpecifiedWindowFrame, WindowExpression}
 import org.apache.spark.sql.rapids.shims.Spark32XShimsUtils
@@ -85,7 +85,7 @@ object GpuWindowUtil {
     case GpuLiteral(value, _: DayTimeIntervalType) =>
       var x = value.asInstanceOf[Long]
       if (x == Long.MinValue) x = Long.MaxValue
-      ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
+      ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
     case anything => throw new UnsupportedOperationException("Unsupported window frame" +
       s" expression $anything")
   }
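A note on the recurring `MinValue` clamp in these boundary parsers: for the fixed-width integral types, `Math.abs` of the minimum value overflows and stays negative, which is why each case clamps `MinValue` to `MaxValue` before taking the absolute value. A quick standard-library demonstration:

```scala
object AbsOverflowSketch {
  def main(args: Array[String]): Unit = {
    // Two's-complement minima have no positive counterpart, so abs overflows:
    println(Math.abs(Long.MinValue)) // -9223372036854775808, still negative
    println(Math.abs(Int.MinValue))  // -2147483648, still negative

    // The clamp used throughout the patch sidesteps this:
    var x = Long.MinValue
    if (x == Long.MinValue) x = Long.MaxValue
    println(Math.abs(x))             // 9223372036854775807
  }
}
```

For floats and doubles, `Float.MinValue`/`Double.MinValue` are finite minima rather than two's-complement values, so `Math.abs` is already well-defined there; the analogous clamps in the new `FloatType`/`DoubleType` cases simply keep the pattern uniform.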