Skip to content

Commit

Permalink
Enable sort for single-level nesting struct columns on GPU (NVIDIA#1883)
Browse files Browse the repository at this point in the history
Adds single-level struct columns to sort. This PR contributes to NVIDIA#1605 

The following limitations apply with this PR for a total sort, and will be resolved in follow-up PR's
- only if the number of partitions is 1 
- only if spark.rapids.sql.stableSort.enabled is true
 
Signed-off-by: Gera Shegalov <gera@apache.org>
  • Loading branch information
gerashegalov authored Mar 26, 2021
1 parent d509a60 commit 8e6762d
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 33 deletions.
10 changes: 5 additions & 5 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,12 @@ Accelerator supports are described below.
<td>S*</td>
<td>S</td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down Expand Up @@ -379,7 +379,7 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down Expand Up @@ -12421,7 +12421,7 @@ Accelerator support is described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -12442,7 +12442,7 @@ Accelerator support is described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down
30 changes: 24 additions & 6 deletions integration_tests/src/main/python/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def start(self, rand):
POS_FLOAT_NAN_MAX_VALUE = struct.unpack('f', struct.pack('I', 0x7fffffff))[0]
class FloatGen(DataGen):
"""Generate floats, which some built in corner cases."""
def __init__(self, nullable=True,
def __init__(self, nullable=True,
no_nans=False, special_cases=None):
self._no_nans = no_nans
if special_cases is None:
Expand Down Expand Up @@ -334,7 +334,7 @@ def gen_float():
POS_DOUBLE_NAN_MAX_VALUE = struct.unpack('d', struct.pack('L', 0x7fffffffffffffff))[0]
class DoubleGen(DataGen):
"""Generate doubles, which some built in corner cases."""
def __init__(self, min_exp=DOUBLE_MIN_EXP, max_exp=DOUBLE_MAX_EXP, no_nans=False,
def __init__(self, min_exp=DOUBLE_MIN_EXP, max_exp=DOUBLE_MAX_EXP, no_nans=False,
nullable=True, special_cases = None):
self._min_exp = min_exp
self._max_exp = max_exp
Expand Down Expand Up @@ -447,7 +447,7 @@ def __init__(self, start=None, end=None, nullable=True):

self._start_day = self._to_days_since_epoch(start)
self._end_day = self._to_days_since_epoch(end)

self.with_special_case(start)
self.with_special_case(end)

Expand Down Expand Up @@ -652,9 +652,27 @@ def gen_scalar_value(data_gen, seed=0, force_no_nulls=False):
v = list(gen_scalar_values(data_gen, 1, seed=seed, force_no_nulls=force_no_nulls))
return v[0]

def debug_df(df):
"""print out the contents of a dataframe for debugging."""
print('COLLECTED\n{}'.format(df.collect()))
def debug_df(df, path = None, file_format = 'json', num_parts = 1):
"""Print out or save the contents and the schema of a dataframe for debugging."""

if path is not None:
# Save the dataframe and its schema
# The schema can be re-created by using DataType.fromJson and used
# for loading the dataframe
file_name = f"{path}.{file_format}"
schema_file_name = f"{path}.schema.json"

df.coalesce(num_parts).write.format(file_format).save(file_name)
print(f"SAVED df output for debugging at {file_name}")

schema_json = df.schema.json()
schema_file = open(schema_file_name , 'w')
schema_file.write(schema_json)
schema_file.close()
print(f"SAVED df schema for debugging along in the output dir")
else:
print('COLLECTED\n{}'.format(df.collect()))

df.explain()
df.printSchema()
return df
Expand Down
66 changes: 66 additions & 0 deletions integration_tests/src/main/python/sort_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,46 @@ def test_single_orderby(data_gen, order):
lambda spark : unary_op_df(spark, data_gen).orderBy(order),
conf = allow_negative_scale_of_decimal_conf)

@pytest.mark.parametrize('shuffle_parts', [
pytest.param(1),
pytest.param(200, marks=pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/1607"))
])
@pytest.mark.parametrize('stable_sort', [
pytest.param(True),
pytest.param(False, marks=pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/1607"))
])
@pytest.mark.parametrize('data_gen', [
pytest.param(all_basic_struct_gen),
pytest.param(StructGen([['child0', all_basic_struct_gen]]),
marks=pytest.mark.xfail(reason='second-level structs are not supported')),
pytest.param(ArrayGen(string_gen),
marks=pytest.mark.xfail(reason="arrays are not supported")),
pytest.param(MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen),
marks=pytest.mark.xfail(reason="maps are not supported")),
], ids=idfn)
@pytest.mark.parametrize('order', [
pytest.param(f.col('a').asc()),
pytest.param(f.col('a').asc_nulls_first()),
pytest.param(f.col('a').asc_nulls_last(),
marks=pytest.mark.xfail(reason='opposite null order not supported')),
pytest.param(f.col('a').desc()),
pytest.param(f.col('a').desc_nulls_first(),
marks=pytest.mark.xfail(reason='opposite null order not supported')),
pytest.param(f.col('a').desc_nulls_last()),
], ids=idfn)
def test_single_nested_orderby_plain(data_gen, order, shuffle_parts, stable_sort):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).orderBy(order),
# TODO no interference with range partition once implemented
conf = {
**allow_negative_scale_of_decimal_conf,
**{
'spark.sql.shuffle.partitions': shuffle_parts,
'spark.rapids.sql.stableSort.enabled': stable_sort,
'spark.rapids.allowCpuRangePartitioning': False
}
})

# SPARK CPU itself has issue with negative scale for take ordered and project
orderable_without_neg_decimal = [n for n in (orderable_gens + orderable_not_null_gen) if not (isinstance(n, DecimalGen) and n.scale < 0)]
@pytest.mark.parametrize('data_gen', orderable_without_neg_decimal, ids=idfn)
Expand All @@ -42,6 +82,32 @@ def test_single_orderby_with_limit(data_gen, order):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).orderBy(order).limit(100))

@pytest.mark.parametrize('data_gen', [
pytest.param(all_basic_struct_gen),
pytest.param(StructGen([['child0', all_basic_struct_gen]]),
marks=pytest.mark.xfail(reason='second-level structs are not supported')),
pytest.param(ArrayGen(string_gen),
marks=pytest.mark.xfail(reason="arrays are not supported")),
pytest.param(MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen),
marks=pytest.mark.xfail(reason="maps are not supported")),
], ids=idfn)
@pytest.mark.parametrize('order', [
pytest.param(f.col('a').asc()),
pytest.param(f.col('a').asc_nulls_first()),
pytest.param(f.col('a').asc_nulls_last(),
marks=pytest.mark.xfail(reason='opposite null order not supported')),
pytest.param(f.col('a').desc()),
pytest.param(f.col('a').desc_nulls_first(),
marks=pytest.mark.xfail(reason='opposite null order not supported')),
pytest.param(f.col('a').desc_nulls_last()),
], ids=idfn)
def test_single_nested_orderby_with_limit(data_gen, order):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).orderBy(order).limit(100),
conf = {
'spark.rapids.allowCpuRangePartitioning': False
})

@pytest.mark.parametrize('data_gen', orderable_gens + orderable_not_null_gen, ids=idfn)
@pytest.mark.parametrize('order', [f.col('a').asc(), f.col('a').asc_nulls_last(), f.col('a').desc(), f.col('a').desc_nulls_first()], ids=idfn)
def test_single_sort_in_part(data_gen, order):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,16 @@ object GpuOverrides {
"\\S", "\\v", "\\V", "\\w", "\\w", "\\p", "$", "\\b", "\\B", "\\A", "\\G", "\\Z", "\\z", "\\R",
"?", "|", "(", ")", "{", "}", "\\k", "\\Q", "\\E", ":", "!", "<=", ">")

private[this] val _commonTypes = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL

private[this] val pluginSupportedOrderableSig = _commonTypes +
TypeSig.STRUCT.nested(_commonTypes)

private[this] def isStructType(dataType: DataType) = dataType match {
case StructType(_) => true
case _ => false
}

// this listener mechanism is global and is intended for use by unit tests only
private val listeners: ListBuffer[GpuOverridesListener] = new ListBuffer[GpuOverridesListener]()

Expand Down Expand Up @@ -1814,16 +1824,28 @@ object GpuOverrides {
expr[SortOrder](
"Sort order",
ExprChecks.projectOnly(
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
pluginSupportedOrderableSig,
TypeSig.orderable,
Seq(ParamCheck(
"input",
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
pluginSupportedOrderableSig,
TypeSig.orderable))),
(a, conf, p, r) => new BaseExprMeta[SortOrder](a, conf, p, r) {
(sortOrder, conf, p, r) => new BaseExprMeta[SortOrder](sortOrder, conf, p, r) {
override def tagExprForGpu(): Unit = {
if (isStructType(sortOrder.dataType)) {
val nullOrdering = sortOrder.nullOrdering
val directionDefaultNullOrdering = sortOrder.direction.defaultNullOrdering
val direction = sortOrder.direction.sql
if (nullOrdering != directionDefaultNullOrdering) {
willNotWorkOnGpu(s"only default null ordering $directionDefaultNullOrdering " +
s"for direction $direction is supported for nested types; actual: ${nullOrdering}")
}
}
}

// One of the few expressions that are not replaced with a GPU version
override def convertToGpu(): Expression =
a.withNewChildren(childExprs.map(_.convertToGpu()))
sortOrder.withNewChildren(childExprs.map(_.convertToGpu()))
}),
expr[Count](
"Count aggregate operator",
Expand Down Expand Up @@ -2499,6 +2521,14 @@ object GpuOverrides {
override val childExprs: Seq[BaseExprMeta[_]] =
rp.ordering.map(GpuOverrides.wrapExpr(_, conf, Some(this)))

override def tagPartForGpu() {
val numPartitions = rp.numPartitions
if (numPartitions > 1 && rp.ordering.exists(so => isStructType(so.dataType))) {
willNotWorkOnGpu("only single partition sort is supported for nested types, " +
s"actual partitions: $numPartitions")
}
}

override def convertToGpu(): GpuPartitioning = {
if (rp.numPartitions > 1) {
val gpuOrdering = childExprs.map(_.convertToGpu()).asInstanceOf[Seq[SortOrder]]
Expand Down Expand Up @@ -2612,7 +2642,7 @@ object GpuOverrides {
}),
exec[TakeOrderedAndProjectExec](
"Take the first limit elements as defined by the sortOrder, and do projection if needed.",
ExecChecks(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.NULL, TypeSig.all),
ExecChecks(pluginSupportedOrderableSig, TypeSig.all),
(takeExec, conf, p, r) =>
new SparkPlanMeta[TakeOrderedAndProjectExec](takeExec, conf, p, r) {
val sortOrder: Seq[BaseExprMeta[SortOrder]] =
Expand Down Expand Up @@ -2678,7 +2708,7 @@ object GpuOverrides {
}),
exec[CollectLimitExec](
"Reduce to single partition and apply limit",
ExecChecks(TypeSig.commonCudfTypes + TypeSig.DECIMAL, TypeSig.all),
ExecChecks(pluginSupportedOrderableSig, TypeSig.all),
(collectLimitExec, conf, p, r) => new GpuCollectLimitMeta(collectLimitExec, conf, p, r))
.disabledByDefault("Collect Limit replacement can be slower on the GPU, if huge number " +
"of rows in a batch it could help by limiting the number of rows transferred from " +
Expand Down Expand Up @@ -2751,9 +2781,16 @@ object GpuOverrides {
"The backend for the sort operator",
// The SortOrder TypeSig will govern what types can actually be used as sorting key data type.
// The types below are allowed as inputs and outputs.
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(), TypeSig.all),
(sort, conf, p, r) => new GpuSortMeta(sort, conf, p, r)),
ExecChecks(pluginSupportedOrderableSig + (TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all),
(sort, conf, p, r) => new GpuSortMeta(sort, conf, p, r) {
override def tagPlanForGpu() {
if (!conf.stableSort && sort.sortOrder.exists(so => isStructType(so.dataType))) {
willNotWorkOnGpu("it's disabled for nested types " +
s"unless ${RapidsConf.STABLE_SORT.key} is true")
}
}
}),
exec[ExpandExec](
"The backend for the expand operator",
ExecChecks(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL, TypeSig.all),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

package com.nvidia.spark.rapids

import org.apache.spark.RangePartitioner
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.RangePartitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, CustomShuffleReaderExec, QueryStageExec, ShuffleQueryStageExec}
Expand Down Expand Up @@ -406,8 +408,11 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
case _: GpuColumnarToRowExecParent => () // Ignored
case _: ExecutedCommandExec => () // Ignored
case _: RDDScanExec => () // Ignored
case _: ShuffleExchangeExec => () // Ignored for now, we don't force it to the GPU if
// children are not on the gpu
case shuffleExchange: ShuffleExchangeExec if conf.cpuRangePartitioningPermitted
|| !shuffleExchange.outputPartitioning.isInstanceOf[RangePartitioning] => {
// Ignored for now, we don't force it to the GPU if
// children are not on the gpu
}
case other =>
if (!plan.supportsColumnar &&
!conf.testingAllowedNonGpu.contains(getBaseNameFromClass(other.getClass.toString))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,12 @@ object RapidsConf {
.booleanConf
.createWithDefault(true)

val CPU_RANGE_PARTITIONING_ALLOWED = conf("spark.rapids.allowCpuRangePartitioning")
.doc("Option to control enforcement of range partitioning on GPU.")
.internal()
.booleanConf
.createWithDefault(true)

private def printSectionHeader(category: String): Unit =
println(s"\n### $category")

Expand Down Expand Up @@ -1287,6 +1293,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val getAlluxioPathsToReplace: Option[Seq[String]] = get(ALLUXIO_PATHS_REPLACE)

lazy val cpuRangePartitioningPermitted = get(CPU_RANGE_PARTITIONING_ALLOWED)

def isOperatorEnabled(key: String, incompat: Boolean, isDisabledByDefault: Boolean): Boolean = {
val default = !(isDisabledByDefault || incompat) || (incompat && isIncompatEnabled)
conf.get(key).map(toBoolean(_, key)).getOrElse(default)
Expand Down
40 changes: 29 additions & 11 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,19 @@ final class TypeSig private(
new TypeSig(it, nt, lt, nts)
}

/**
* Remove a type signature. The reverse of +
* @param other what to remove
* @return the new signature
*/
def - (other: TypeSig): TypeSig = {
val it = initialTypes -- other.initialTypes
val nt = nestedTypes -- other.nestedTypes
val lt = litOnlyTypes -- other.litOnlyTypes
val nts = notes -- other.notes.keySet
new TypeSig(it, nt, lt, nts)
}

/**
* Add nested types to this type signature. Note that these do not stack so if nesting has
* nested types too they are ignored.
Expand Down Expand Up @@ -542,18 +555,23 @@ class ExecChecks private(
override def tag(meta: RapidsMeta[_, _, _]): Unit = {
val plan = meta.wrapped.asInstanceOf[SparkPlan]
val allowDecimal = meta.conf.decimalTypeEnabled
if (!check.areAllSupportedByPlugin(plan.output.map(_.dataType), allowDecimal)) {
val unsupported = plan.output.map(_.dataType)
.filter(!check.isSupportedByPlugin(_, allowDecimal))
.toSet
meta.willNotWorkOnGpu(s"unsupported data types in output: ${unsupported.mkString(", ")}")

val unsupportedOutputTypes = plan.output
.filterNot(attr => check.isSupportedByPlugin(attr.dataType, allowDecimal))
.toSet

if (unsupportedOutputTypes.nonEmpty) {
meta.willNotWorkOnGpu("unsupported data types in output: " +
unsupportedOutputTypes.mkString(", "))
}
if (!check.areAllSupportedByPlugin(
plan.children.flatMap(_.output.map(_.dataType)),
allowDecimal)) {
val unsupported = plan.children.flatMap(_.output.map(_.dataType))
.filter(!check.isSupportedByPlugin(_, allowDecimal)).toSet
meta.willNotWorkOnGpu(s"unsupported data types in input: ${unsupported.mkString(", ")}")

val unsupportedInputTypes = plan.children.flatMap { childPlan =>
childPlan.output.filterNot(attr => check.isSupportedByPlugin(attr.dataType, allowDecimal))
}.toSet

if (unsupportedInputTypes.nonEmpty) {
meta.willNotWorkOnGpu("unsupported data types in input: " +
unsupportedInputTypes.mkString(", "))
}
}

Expand Down

0 comments on commit 8e6762d

Please sign in to comment.