Skip to content

Commit

Permalink
Support Float order-by columns for RANGE window functions
Browse files Browse the repository at this point in the history
Depends on rapidsai/cudf#13635.

This commit adds support for floating point order-by columns in
RANGE based window functions.

Prior to this commit, when the `GpuWindowExec` was presented with an
order-by column of floating-point type, the entire window operation
would fall back to CPU execution. This should now execute entirely
on the GPU.
  • Loading branch information
mythrocks committed Jun 29, 2023
1 parent 33c63fd commit 98b6528
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 56 deletions.
24 changes: 18 additions & 6 deletions integration_tests/src/main/python/window_function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pyspark.sql.types import NumericType
from pyspark.sql.window import Window
import pyspark.sql.functions as f
from spark_session import is_before_spark_320, is_before_spark_340, is_databricks113_or_later
from spark_session import is_before_spark_320, is_databricks113_or_later
import warnings

_grpkey_longs_with_no_nulls = [
Expand Down Expand Up @@ -74,6 +74,16 @@
('b', DecimalGen(precision=38, scale=2, nullable=True)),
('c', DecimalGen(precision=38, scale=2, nullable=True))]

_grpkey_longs_with_nullable_floats = [
('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
('b', FloatGen(nullable=True)),
('c', IntegerGen(nullable=True))]

_grpkey_longs_with_nullable_doubles = [
('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
('b', DoubleGen(nullable=True)),
('c', IntegerGen(nullable=True))]

_grpkey_decimals_with_nulls = [
('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
('b', IntegerGen()),
Expand Down Expand Up @@ -879,15 +889,17 @@ def test_window_aggs_for_ranges_timestamps(data_gen):
pytest.param(_grpkey_longs_with_nullable_largest_decimals,
marks=pytest.mark.xfail(
condition=is_databricks113_or_later(),
reason='https://github.com/NVIDIA/spark-rapids/issues/7429'))
reason='https://github.com/NVIDIA/spark-rapids/issues/7429')),
_grpkey_longs_with_nullable_floats,
_grpkey_longs_with_nullable_doubles
], ids=idfn)
def test_window_aggregations_for_decimal_ranges(data_gen):
def test_window_aggregations_for_decimal_and_float_ranges(data_gen):
"""
Tests for range window aggregations, with DECIMAL order by columns.
Tests for range window aggregations, with DECIMAL/FLOATING POINT order by columns.
The table schema used:
a: Group By column
b: Order By column (decimal)
c: Aggregation column (incidentally, also decimal)
b: Order By column (decimals, floats, doubles)
c: Aggregation column (decimals or ints)
Since this test is for the order-by column type, and not for each specific windowing aggregation,
we use COUNT(1) throughout the test, for different window widths and ordering.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -943,10 +943,12 @@ object GpuOverrides extends Logging {
TypeSig.numericAndInterval,
Seq(
ParamCheck("lower",
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_128,
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_128 +
TypeSig.FLOAT + TypeSig.DOUBLE,
TypeSig.numericAndInterval),
ParamCheck("upper",
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_128,
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_128 +
TypeSig.FLOAT + TypeSig.DOUBLE,
TypeSig.numericAndInterval))),
(windowFrame, conf, p, r) => new GpuSpecifiedWindowFrameMeta(windowFrame, conf, p, r) ),
expr[WindowSpecDefinition](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,47 @@ case class BoundGpuWindowFunction(
val dataType: DataType = windowFunc.dataType
}

case class ParsedBoundary(isUnbounded: Boolean, value: Either[BigInt, Long])
/**
* Abstraction for possible range-boundary specifications.
*
* This provides type disjunction for Long, BigInt and Double,
* the three types that might represent a range boundary.
*/
abstract class RangeBoundaryValue {
def long: Long = RangeBoundaryValue.long(this)
def bigInt: BigInt = RangeBoundaryValue.bigInt(this)
def double: Double = RangeBoundaryValue.double(this)
}

case class LongRangeBoundaryValue(value: Long) extends RangeBoundaryValue
case class BigIntRangeBoundaryValue(value: BigInt) extends RangeBoundaryValue
case class DoubleRangeBoundaryValue(value: Double) extends RangeBoundaryValue

object RangeBoundaryValue {

def long(boundary: RangeBoundaryValue): Long = boundary match {
case LongRangeBoundaryValue(l) => l
case other => throw new NoSuchElementException(s"Cannot get `long` from $other")
}

def bigInt(boundary: RangeBoundaryValue): BigInt = boundary match {
case BigIntRangeBoundaryValue(b) => b
case other => throw new NoSuchElementException(s"Cannot get `bigInt` from $other")
}

def double(boundary: RangeBoundaryValue): Double = boundary match {
case DoubleRangeBoundaryValue(d) => d
case other => throw new NoSuchElementException(s"Cannot get `double` from $other")
}

def long(value: Long): LongRangeBoundaryValue = LongRangeBoundaryValue(value)

def bigInt(value: BigInt): BigIntRangeBoundaryValue = BigIntRangeBoundaryValue(value)

def double(value: Double): DoubleRangeBoundaryValue = DoubleRangeBoundaryValue(value)
}

case class ParsedBoundary(isUnbounded: Boolean, value: RangeBoundaryValue)

object GroupedAggregations {
/**
Expand Down Expand Up @@ -754,22 +794,23 @@ object GroupedAggregations {
if (bound.isUnbounded) {
None
} else {
val valueLong = bound.value.right // Used for all cases except DECIMAL128.
val s = orderByType match {
case DType.INT8 => Scalar.fromByte(valueLong.get.toByte)
case DType.INT16 => Scalar.fromShort(valueLong.get.toShort)
case DType.INT32 => Scalar.fromInt(valueLong.get.toInt)
case DType.INT64 => Scalar.fromLong(valueLong.get)
case DType.INT8 => Scalar.fromByte(bound.value.long.toByte)
case DType.INT16 => Scalar.fromShort(bound.value.long.toShort)
case DType.INT32 => Scalar.fromInt(bound.value.long.toInt)
case DType.INT64 => Scalar.fromLong(bound.value.long)
case DType.FLOAT32 => Scalar.fromFloat(bound.value.double.toFloat)
case DType.FLOAT64 => Scalar.fromDouble(bound.value.double)
// Interval is not working for DateType
case DType.TIMESTAMP_DAYS => Scalar.durationFromLong(DType.DURATION_DAYS, valueLong.get)
case DType.TIMESTAMP_DAYS => Scalar.durationFromLong(DType.DURATION_DAYS, bound.value.long)
case DType.TIMESTAMP_MICROSECONDS =>
Scalar.durationFromLong(DType.DURATION_MICROSECONDS, valueLong.get)
Scalar.durationFromLong(DType.DURATION_MICROSECONDS, bound.value.long)
case x if x.getTypeId == DType.DTypeEnum.DECIMAL32 =>
Scalar.fromDecimal(x.getScale, valueLong.get.toInt)
Scalar.fromDecimal(x.getScale, bound.value.long.toInt)
case x if x.getTypeId == DType.DTypeEnum.DECIMAL64 =>
Scalar.fromDecimal(x.getScale, valueLong.get)
Scalar.fromDecimal(x.getScale, bound.value.long)
case x if x.getTypeId == DType.DTypeEnum.DECIMAL128 =>
Scalar.fromDecimal(x.getScale, bound.value.left.get.underlying())
Scalar.fromDecimal(x.getScale, bound.value.bigInt.underlying())
case x if x.getTypeId == DType.DTypeEnum.STRING =>
// Not UNBOUNDED. The only other supported boundary for String is CURRENT ROW, i.e. 0.
Scalar.fromString("")
Expand All @@ -782,36 +823,52 @@ object GroupedAggregations {
private def getRangeBoundaryValue(boundary: Expression, orderByType: DType): ParsedBoundary =
boundary match {
case special: GpuSpecialFrameBoundary =>
val isUnBounded = special.isUnbounded
val isDecimal128 = orderByType.getTypeId == DType.DTypeEnum.DECIMAL128
ParsedBoundary(isUnBounded, if (isDecimal128) Left(special.value) else Right(special.value))
ParsedBoundary(
isUnbounded = special.isUnbounded,
value = orderByType.getTypeId match {
case DType.DTypeEnum.DECIMAL128 => RangeBoundaryValue.bigInt(special.value)
case DType.DTypeEnum.FLOAT32 | DType.DTypeEnum.FLOAT64 =>
RangeBoundaryValue.double(special.value)
case _ => RangeBoundaryValue.long(special.value)
}
)
case GpuLiteral(ci: CalendarInterval, CalendarIntervalType) =>
// Get the total microseconds for TIMESTAMP_MICROSECONDS
var x = TimeUnit.DAYS.toMicros(ci.days) + ci.microseconds
if (x == Long.MinValue) x = Long.MaxValue
ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
case GpuLiteral(value, ByteType) =>
var x = value.asInstanceOf[Byte]
if (x == Byte.MinValue) x = Byte.MaxValue
ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
case GpuLiteral(value, ShortType) =>
var x = value.asInstanceOf[Short]
if (x == Short.MinValue) x = Short.MaxValue
ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
case GpuLiteral(value, IntegerType) =>
var x = value.asInstanceOf[Int]
if (x == Int.MinValue) x = Int.MaxValue
ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
case GpuLiteral(value, LongType) =>
var x = value.asInstanceOf[Long]
if (x == Long.MinValue) x = Long.MaxValue
ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
case GpuLiteral(value, FloatType) =>
var x = value.asInstanceOf[Float]
if (x == Float.MinValue) x = Float.MaxValue
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.double(Math.abs(x)))
case GpuLiteral(value, DoubleType) =>
var x = value.asInstanceOf[Double]
if (x == Double.MinValue) x = Double.MaxValue
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.double(Math.abs(x)))
case GpuLiteral(value: Decimal, DecimalType()) =>
orderByType.getTypeId match {
case DType.DTypeEnum.DECIMAL32 | DType.DTypeEnum.DECIMAL64 =>
ParsedBoundary(isUnbounded = false, Right(Math.abs(value.toUnscaledLong)))
ParsedBoundary(isUnbounded = false,
RangeBoundaryValue.long(Math.abs(value.toUnscaledLong)))
case DType.DTypeEnum.DECIMAL128 =>
ParsedBoundary(isUnbounded = false, Left(value.toJavaBigDecimal.unscaledValue().abs))
ParsedBoundary(isUnbounded = false,
RangeBoundaryValue.bigInt(value.toJavaBigDecimal.unscaledValue().abs))
case anythingElse =>
throw new UnsupportedOperationException(s"Unexpected Decimal type: $anythingElse")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ abstract class GpuWindowExpressionMetaBase(
val orderSpec = wrapped.windowSpec.orderSpec
if (orderSpec.length > 1) {
// We only support a single order by column
willNotWorkOnGpu("only a single date/time or integral (Boolean exclusive)" +
willNotWorkOnGpu("only a single date/time or numeric (Boolean exclusive) " +
"based column in window range functions is supported")
}
val orderByTypeSupported = orderSpec.forall { so =>
so.dataType match {
case ByteType | ShortType | IntegerType | LongType |
case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
DateType | TimestampType | StringType | DecimalType() => true
case _ => false
}
Expand Down Expand Up @@ -134,6 +134,14 @@ abstract class GpuWindowExpressionMetaBase(
s"Range window frame is not 100% compatible when the order by type is " +
s"long and the range value calculated has overflow. " +
s"To enable it please set ${RapidsConf.ENABLE_RANGE_WINDOW_LONG} to true.")
case FloatType => if (!conf.isRangeWindowFloatEnabled) willNotWorkOnGpu(
s"Range window frame is not 100% compatible when the order by type is " +
s"float and the range value calculated has overflow. " +
s"To enable it please set ${RapidsConf.ENABLE_RANGE_WINDOW_FLOAT} to true.")
case DoubleType => if (!conf.isRangeWindowDoubleEnabled) willNotWorkOnGpu(
s"Range window frame is not 100% compatible when the order by type is " +
s"double and the range value calculated has overflow. " +
s"To enable it please set ${RapidsConf.ENABLE_RANGE_WINDOW_DOUBLE} to true.")
case DecimalType() => if (!conf.isRangeWindowDecimalEnabled) willNotWorkOnGpu(
s"To enable DECIMAL order by columns with Range window frames, " +
s"please set ${RapidsConf.ENABLE_RANGE_WINDOW_DECIMAL} to true.")
Expand All @@ -144,7 +152,7 @@ abstract class GpuWindowExpressionMetaBase(
// check whether the boundaries are supported or not.
Seq(spec.lower, spec.upper).foreach {
case l @ Literal(_, ByteType | ShortType | IntegerType |
LongType | DecimalType()) =>
LongType | FloatType | DoubleType | DecimalType()) =>
checkRangeBoundaryConfig(l.dataType)
case Literal(ci: CalendarInterval, CalendarIntervalType) =>
// interval is only working for TimeStampType
Expand Down Expand Up @@ -356,7 +364,7 @@ abstract class GpuSpecifiedWindowFrameMetaBase(
* Tag RangeFrame for other types and get the value
*/
def getAndTagOtherTypesForRangeFrame(bounds : Expression, isLower : Boolean): Long = {
willNotWorkOnGpu(s"Bounds for Range-based window frames must be specified in Integral" +
willNotWorkOnGpu(s"Bounds for Range-based window frames must be specified in numeric" +
s" type (Boolean exclusive) or CalendarInterval. Found ${bounds.dataType}")
if (isLower) -1 else 1 // not check again
}
Expand All @@ -377,36 +385,58 @@ abstract class GpuSpecifiedWindowFrameMetaBase(
return None
}

val value: BigInt = bounds match {
case Literal(value, ByteType) => value.asInstanceOf[Byte].toLong
case Literal(value, ShortType) => value.asInstanceOf[Short].toLong
case Literal(value, IntegerType) => value.asInstanceOf[Int].toLong
case Literal(value, LongType) => value.asInstanceOf[Long]
case Literal(value: Decimal, DecimalType()) => value.toJavaBigDecimal.unscaledValue()
/**
* Check bounds value relative to current row:
* 1. lower-bound should not be ahead of the current row.
* 2. upper-bound should not be behind the current row.
*/
def checkBounds[T](boundsValue: T)
(implicit ev: Numeric[T]): Option[String] = {
if (isLower && ev.compare(boundsValue, ev.zero) > 0) {
Some(s"Lower-bounds ahead of current row is not supported. Found: $boundsValue")
}
else if (!isLower && ev.compare(boundsValue, ev.zero) < 0) {
Some(s"Upper-bounds behind current row is not supported. Found: $boundsValue")
}
else {
None
}
}

bounds match {
case Literal(value, ByteType) =>
checkBounds(value.asInstanceOf[Byte].toLong)
case Literal(value, ShortType) =>
checkBounds(value.asInstanceOf[Short].toLong)
case Literal(value, IntegerType) =>
checkBounds(value.asInstanceOf[Int].toLong)
case Literal(value, LongType) =>
checkBounds(value.asInstanceOf[Long])
case Literal(value, FloatType) =>
checkBounds(value.asInstanceOf[Float])
case Literal(value, DoubleType) =>
checkBounds(value.asInstanceOf[Double])
case Literal(value: Decimal, DecimalType()) =>
checkBounds(BigInt(value.toJavaBigDecimal.unscaledValue()))
case Literal(ci: CalendarInterval, CalendarIntervalType) =>
if (ci.months != 0) {
willNotWorkOnGpu("interval months isn't supported")
}
// return the total microseconds
try {
Math.addExact(
Math.multiplyExact(ci.days.toLong, TimeUnit.DAYS.toMicros(1)),
ci.microseconds)
checkBounds(
Math.addExact(
Math.multiplyExact(ci.days.toLong, TimeUnit.DAYS.toMicros(1)),
ci.microseconds))
} catch {
case _: ArithmeticException =>
willNotWorkOnGpu("windows over timestamps are converted to microseconds " +
s"and $ci is too large to fit")
if (isLower) -1 else 1 // not check again
None
}
case _ => getAndTagOtherTypesForRangeFrame(bounds, isLower)
}

if (isLower && value > 0) {
Some(s"Lower-bounds ahead of current row is not supported. Found: $value")
} else if (!isLower && value < 0) {
Some(s"Upper-bounds behind current row is not supported. Found: $value")
} else {
None
case _ =>
getAndTagOtherTypesForRangeFrame(bounds, isLower)
None
}
}

Expand Down
18 changes: 18 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,20 @@ object RapidsConf {
.booleanConf
.createWithDefault(true)

val ENABLE_RANGE_WINDOW_FLOAT: ConfEntryWithDefault[Boolean] =
conf("spark.rapids.sql.window.range.float.enabled")
.doc("When set to false, this disables the range window acceleration for the " +
"FLOAT type order-by column")
.booleanConf
.createWithDefault(true)

val ENABLE_RANGE_WINDOW_DOUBLE: ConfEntryWithDefault[Boolean] =
conf("spark.rapids.sql.window.range.double.enabled")
.doc("When set to false, this disables the range window acceleration for the " +
"double type order-by column")
.booleanConf
.createWithDefault(true)

val ENABLE_RANGE_WINDOW_DECIMAL: ConfEntryWithDefault[Boolean] =
conf("spark.rapids.sql.window.range.decimal.enabled")
.doc("When set to false, this disables the range window acceleration for the " +
Expand Down Expand Up @@ -2569,6 +2583,10 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val isRangeWindowLongEnabled: Boolean = get(ENABLE_RANGE_WINDOW_LONG)

lazy val isRangeWindowFloatEnabled: Boolean = get(ENABLE_RANGE_WINDOW_FLOAT)

lazy val isRangeWindowDoubleEnabled: Boolean = get(ENABLE_RANGE_WINDOW_DOUBLE)

lazy val isRangeWindowDecimalEnabled: Boolean = get(ENABLE_RANGE_WINDOW_DECIMAL)

lazy val allowSinglePassPartialSortAgg: Boolean = get(ENABLE_SINGLE_PASS_PARTIAL_SORT_AGG)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,11 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
Seq(
ParamCheck("lower",
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DAYTIME
+ TypeSig.DECIMAL_128,
+ TypeSig.DECIMAL_128 + TypeSig.FLOAT + TypeSig.DOUBLE,
TypeSig.numericAndInterval),
ParamCheck("upper",
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DAYTIME
+ TypeSig.DECIMAL_128,
+ TypeSig.DECIMAL_128 + TypeSig.FLOAT + TypeSig.DOUBLE,
TypeSig.numericAndInterval))),
(windowFrame, conf, p, r) => new GpuSpecifiedWindowFrameMeta(windowFrame, conf, p, r)),
GpuOverrides.expr[WindowExpression](
Expand Down
Loading

0 comments on commit 98b6528

Please sign in to comment.