Adds in basic support for decimal sort, sum, and some shuffle (#1380)
Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
revans2 authored Dec 14, 2020
1 parent 4b9bc34 commit 5c43fd3
Showing 12 changed files with 199 additions and 29 deletions.
2 changes: 2 additions & 0 deletions docs/configs.md
@@ -176,6 +176,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Log2"></a>spark.rapids.sql.expression.Log2|`log2`|Log base 2|true|None|
<a name="sql.expression.Logarithm"></a>spark.rapids.sql.expression.Logarithm|`log`|Log variable base|true|None|
<a name="sql.expression.Lower"></a>spark.rapids.sql.expression.Lower|`lower`, `lcase`|String lowercase operator|false|This is not 100% compatible with the Spark version because in some cases unicode characters change byte width when changing the case. The GPU string conversion does not support these characters. For a full list of unsupported characters see https://github.com/rapidsai/cudf/issues/3132|
<a name="sql.expression.MakeDecimal"></a>spark.rapids.sql.expression.MakeDecimal| |Create a Decimal from an unscaled long value form some aggregation optimizations|true|None|
<a name="sql.expression.Md5"></a>spark.rapids.sql.expression.Md5|`md5`|MD5 hash operator|true|None|
<a name="sql.expression.Minute"></a>spark.rapids.sql.expression.Minute|`minute`|Returns the minute component of the string/timestamp|true|None|
<a name="sql.expression.MonotonicallyIncreasingID"></a>spark.rapids.sql.expression.MonotonicallyIncreasingID|`monotonically_increasing_id`|Returns monotonically increasing 64-bit integers|true|None|
@@ -228,6 +229,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.UnboundedFollowing$"></a>spark.rapids.sql.expression.UnboundedFollowing$| |Special boundary for a window frame, indicating all rows preceding the current row|true|None|
<a name="sql.expression.UnboundedPreceding$"></a>spark.rapids.sql.expression.UnboundedPreceding$| |Special boundary for a window frame, indicating all rows preceding the current row|true|None|
<a name="sql.expression.UnixTimestamp"></a>spark.rapids.sql.expression.UnixTimestamp|`unix_timestamp`|Returns the UNIX timestamp of current or specified time|true|None|
<a name="sql.expression.UnscaledValue"></a>spark.rapids.sql.expression.UnscaledValue| |Convert a Decimal to an unscaled long value for some aggregation optimizations|true|None|
<a name="sql.expression.Upper"></a>spark.rapids.sql.expression.Upper|`upper`, `ucase`|String uppercase operator|false|This is not 100% compatible with the Spark version because in some cases unicode characters change byte width when changing the case. The GPU string conversion does not support these characters. For a full list of unsupported characters see https://github.com/rapidsai/cudf/issues/3132|
<a name="sql.expression.WeekDay"></a>spark.rapids.sql.expression.WeekDay|`weekday`|Returns the day of the week (0 = Monday...6=Sunday)|true|None|
<a name="sql.expression.WindowExpression"></a>spark.rapids.sql.expression.WindowExpression| |Calculates a return value for every input row of a table based on a group (or "window") of rows|true|None|
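The two new `MakeDecimal` and `UnscaledValue` entries above only come into play once decimal support is enabled in the plugin. A minimal sketch of turning it on from a Spark session, assuming the `spark.rapids.sql.decimalType.enabled` key that backs `conf.decimalTypeEnabled` later in this change and the usual `com.nvidia.spark.SQLPlugin` plugin class (neither is shown in this diff):

```scala
import org.apache.spark.sql.SparkSession

object EnableDecimalSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("rapids-decimal-sketch")
      // Assumed keys: the plugin class and the decimal flag are not part of this diff.
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
      .config("spark.rapids.sql.decimalType.enabled", "true")
      .getOrCreate()

    // Summing a small-precision decimal column; for small enough precisions Spark may
    // rewrite the aggregation through unscaled longs, which is where the two new
    // expressions appear in the physical plan.
    spark.range(100)
      .selectExpr("CAST(id AS DECIMAL(7,3)) AS d")
      .selectExpr("SUM(d) AS total")
      .show()
  }
}
```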
5 changes: 4 additions & 1 deletion integration_tests/src/main/python/data_gen.py
@@ -28,6 +28,8 @@ class DataGen:
"""Base class for data generation"""

def __repr__(self):
if not self.nullable:
return self.__class__.__name__[:-3] + '(not_null)'
return self.__class__.__name__[:-3]

def __hash__(self):
@@ -722,6 +724,7 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
decimal_gen_neg_scale = DecimalGen(precision=7, scale=-3)
decimal_gen_scale_precision = DecimalGen(precision=7, scale=3)
decimal_gen_same_scale_precision = DecimalGen(precision=7, scale=7)
decimal_gen_64bit = DecimalGen(precision=12, scale=2)

null_gen = NullGen()

@@ -734,7 +737,7 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
double_n_long_gens = [double_gen, long_gen]
int_n_long_gens = [int_gen, long_gen]
decimal_gens = [decimal_gen_default, decimal_gen_neg_scale, decimal_gen_scale_precision,
decimal_gen_same_scale_precision]
decimal_gen_same_scale_precision, decimal_gen_64bit]

# all of the basic gens
all_basic_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
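The new `decimal_gen_64bit` generator produces values for a `DecimalType(12, 2)` column. The name reflects that any precision up to 18 digits fits in a single 64-bit unscaled long, which is what the DECIMAL64 support in this change relies on. A REPL-style sketch of that bound:

```scala
import org.apache.spark.sql.types.{Decimal, DecimalType}

object DecimalFitsInLongSketch extends App {
  val dt = DecimalType(12, 2)
  // Decimal.MAX_LONG_DIGITS is 18: precisions up to that are stored as an unscaled Long.
  assert(dt.precision <= Decimal.MAX_LONG_DIGITS)
  println(s"DecimalType(${dt.precision},${dt.scale}) fits in an unscaled Long")
}
```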
12 changes: 10 additions & 2 deletions integration_tests/src/main/python/hash_aggregate_test.py
@@ -162,7 +162,7 @@ def get_params(init_list, marked_params=[]):
# Pytest marker for list of operators allowed to run on the CPU,
# esp. useful in partial and final only modes.
_excluded_operators_marker = pytest.mark.allow_non_gpu(
'HashAggregateExec', 'AggregateExpression',
'HashAggregateExec', 'AggregateExpression', 'UnscaledValue', 'MakeDecimal',
'AttributeReference', 'Alias', 'Sum', 'Count', 'Max', 'Min', 'Average', 'Cast',
'KnownFloatingPointNormalized', 'NormalizeNaNAndZero', 'GreaterThan', 'Literal', 'If',
'EqualTo', 'First', 'SortAggregateExec', 'Coalesce')
@@ -173,10 +173,18 @@ def get_params(init_list, marked_params=[]):
]


_grpkey_small_decimals = [
('a', RepeatSeqGen(DecimalGen(precision=7, scale=3, nullable=(True, 10.0)), length=50)),
('b', DecimalGen(precision=5, scale=2)),
('c', DecimalGen(precision=8, scale=3))]

_init_list_no_nans_with_decimal = _init_list_no_nans + [
_grpkey_small_decimals]

@approximate_float
@ignore_order
@incompat
@pytest.mark.parametrize('data_gen', _init_list_no_nans, ids=idfn)
@pytest.mark.parametrize('data_gen', _init_list_no_nans_with_decimal, ids=idfn)
@pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
def test_hash_grpby_sum(data_gen, conf):
assert_gpu_and_cpu_are_equal_collect(
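With `_grpkey_small_decimals` added to the parameter list, `test_hash_grpby_sum` now also covers decimal group-by keys and values. The query body is cut off above; the shape being exercised is roughly the following (a hypothetical Scala equivalent of the generated `a`/`b`/`c` columns, not the literal test code):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.sum

object GroupBySumSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("grpby-sum-sketch").getOrCreate()
    // Hypothetical stand-in for the generated data: a repeating DECIMAL(7,3) key
    // plus DECIMAL(5,2) and DECIMAL(8,3) values, mirroring _grpkey_small_decimals.
    val df = spark.range(1000).selectExpr(
      "CAST(id % 50 AS DECIMAL(7,3)) AS a",
      "CAST(id % 100 AS DECIMAL(5,2)) AS b",
      "CAST(id AS DECIMAL(8,3)) AS c")
    df.groupBy("a").agg(sum("b"), sum("c")).show()
  }
}
```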
5 changes: 3 additions & 2 deletions integration_tests/src/main/python/repart_test.py
@@ -41,6 +41,7 @@ def test_coalesce_df(num_parts, length):
@ignore_order(local=True) # To avoid extra data shuffle by 'sort on Spark' for this repartition test.
def test_repartion_df(num_parts, length):
#This should change eventually to be more than just the basic gens
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(all_basic_gens)]
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(all_basic_gens + decimal_gens)]
assert_gpu_and_cpu_are_equal_collect(
lambda spark : gen_df(spark, gen_list, length=length).repartition(num_parts))
lambda spark : gen_df(spark, gen_list, length=length).repartition(num_parts),
conf = allow_negative_scale_of_decimal_conf)
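Several of the decimal generators use a negative scale, which Spark only accepts with a legacy flag, hence the added `conf = allow_negative_scale_of_decimal_conf`. A sketch of the same repartition on a decimal column from plain Spark, assuming that helper maps to Spark's `spark.sql.legacy.allowNegativeScaleOfDecimal` setting:

```scala
import org.apache.spark.sql.SparkSession

object RepartitionDecimalSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("repartition-decimal-sketch")
      // Assumed mapping for allow_negative_scale_of_decimal_conf in the test helpers.
      .config("spark.sql.legacy.allowNegativeScaleOfDecimal", "true")
      .getOrCreate()

    // Round-robin repartition forces a shuffle that now has to carry decimal data.
    val df = spark.range(10000).selectExpr("CAST(id AS DECIMAL(12,2)) AS d")
    println(df.repartition(4).count())
  }
}
```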
36 changes: 16 additions & 20 deletions integration_tests/src/main/python/sort_test.py
@@ -20,36 +20,32 @@
from pyspark.sql.types import *
import pyspark.sql.functions as f

orderable_gen_classes = [ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen,
BooleanGen, TimestampGen, DateGen, StringGen, NullGen]
orderable_not_null_gen = [ByteGen(nullable=False), ShortGen(nullable=False), IntegerGen(nullable=False),
LongGen(nullable=False), FloatGen(nullable=False), DoubleGen(nullable=False), BooleanGen(nullable=False),
TimestampGen(nullable=False), DateGen(nullable=False), StringGen(nullable=False), DecimalGen(nullable=False),
DecimalGen(precision=7, scale=-3, nullable=False), DecimalGen(precision=7, scale=3, nullable=False),
DecimalGen(precision=7, scale=7, nullable=False), DecimalGen(precision=12, scale=2, nullable=False)]

@pytest.mark.parametrize('data_gen_class', orderable_gen_classes, ids=idfn)
@pytest.mark.parametrize('nullable', [True, False], ids=idfn)
@pytest.mark.parametrize('data_gen', orderable_gens + orderable_not_null_gen, ids=idfn)
@pytest.mark.parametrize('order', [f.col('a').asc(), f.col('a').asc_nulls_last(), f.col('a').desc(), f.col('a').desc_nulls_first()], ids=idfn)
def test_single_orderby(data_gen_class, nullable, order):
if (data_gen_class == NullGen):
data_gen = data_gen_class()
else:
data_gen = data_gen_class(nullable=nullable)
def test_single_orderby(data_gen, order):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).orderBy(order))
lambda spark : unary_op_df(spark, data_gen).orderBy(order),
conf = allow_negative_scale_of_decimal_conf)

@pytest.mark.parametrize('data_gen_class', orderable_gen_classes, ids=idfn)
@pytest.mark.parametrize('nullable', [True, False], ids=idfn)
@pytest.mark.parametrize('data_gen', orderable_gens + orderable_not_null_gen, ids=idfn)
@pytest.mark.parametrize('order', [f.col('a').asc(), f.col('a').asc_nulls_last(), f.col('a').desc(), f.col('a').desc_nulls_first()], ids=idfn)
def test_single_sort_in_part(data_gen_class, nullable, order):
if (data_gen_class == NullGen):
data_gen = data_gen_class()
else:
data_gen = data_gen_class(nullable=nullable)
def test_single_sort_in_part(data_gen, order):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).sortWithinPartitions(order))
lambda spark : unary_op_df(spark, data_gen).sortWithinPartitions(order),
conf = allow_negative_scale_of_decimal_conf)

orderable_gens_sort = [byte_gen, short_gen, int_gen, long_gen,
pytest.param(float_gen, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/84')),
pytest.param(double_gen, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/84')),
boolean_gen, timestamp_gen, date_gen, string_gen, null_gen]
boolean_gen, timestamp_gen, date_gen, string_gen, null_gen] + decimal_gens
@pytest.mark.parametrize('data_gen', orderable_gens_sort, ids=idfn)
def test_multi_orderby(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).orderBy(f.col('a'), f.col('b').desc()))
lambda spark : binary_op_df(spark, data_gen).orderBy(f.col('a'), f.col('b').desc()),
conf = allow_negative_scale_of_decimal_conf)
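The sort tests now take concrete generator instances (nullable and non-nullable) instead of building them from classes, and decimal generators are included in both lists. The orderings under test look like this in plain Spark (hypothetical data standing in for the generated column `a`):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object OrderByDecimalSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("orderby-decimal-sketch").getOrCreate()
    // Hypothetical stand-in for the generated column: every third row is null.
    val df = spark.range(100).selectExpr(
      "CASE WHEN id % 3 = 0 THEN NULL ELSE CAST(id AS DECIMAL(12,2)) END AS a")
    // The four orderings covered by the test parameters.
    df.orderBy(col("a").asc).show()
    df.orderBy(col("a").asc_nulls_last).show()
    df.orderBy(col("a").desc).show()
    df.orderBy(col("a").desc_nulls_first).show()
  }
}
```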
@@ -39,6 +39,59 @@
*/
public class GpuColumnVector extends GpuColumnVectorBase {

/**
* Print to standard error the contents of a table. Note that this should never be
* called from production code, as it is very slow. Also note that this is not production
* code. You might need/want to update how the data shows up or add in support for more
* types as this really is just for debugging.
* @param name the name of the table to print out.
* @param table the table to print out.
*/
public static synchronized void debug(String name, Table table) {
System.err.println("DEBUG " + name + " " + table);
for (int col = 0; col < table.getNumberOfColumns(); col++) {
debug(String.valueOf(col), table.getColumn(col));
}
}

/**
* Print to standard error the contents of a column. Note that this should never be
* called from production code, as it is very slow. Also note that this is not production
* code. You might need/want to update how the data shows up or add in support for more
* types as this really is just for debugging.
* @param name the name of the column to print out.
* @param col the column to print out.
*/
public static synchronized void debug(String name, ai.rapids.cudf.ColumnVector col) {
try (HostColumnVector hostCol = col.copyToHost()) {
debug(name, hostCol);
}
}

/**
* Print to standard error the contents of a column. Note that this should never be
* called from production code, as it is very slow. Also note that this is not production
* code. You might need/want to update how the data shows up or add in support for more
* types as this really is just for debugging.
* @param name the name of the column to print out.
* @param hostCol the column to print out.
*/
public static synchronized void debug(String name, HostColumnVector hostCol) {
DType type = hostCol.getType();
System.err.println("COLUMN " + name + " " + type);
if (type.getTypeId() == DType.DTypeEnum.DECIMAL64) {
for (int i = 0; i < hostCol.getRowCount(); i++) {
if (hostCol.isNull(i)) {
System.err.println(i + " NULL");
} else {
System.err.println(i + " " + hostCol.getBigDecimal(i));
}
}
} else {
System.err.println("TYPE " + type + " NOT SUPPORTED FOR DEBUG PRINT");
}
}

private static HostColumnVector.DataType convertFrom(DataType spark, boolean nullable) {
if (spark instanceof ArrayType) {
ArrayType arrayType = (ArrayType) spark;
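The three `debug` overloads above are meant purely for hand debugging, for example while chasing a decimal mismatch. A sketch of a call site, assuming the class lives in the usual `com.nvidia.spark.rapids` package:

```scala
import ai.rapids.cudf.Table
import com.nvidia.spark.rapids.GpuColumnVector

object DebugDumpSketch {
  // Dump a whole table, then just its first column, to stderr.
  def dumpForDebug(table: Table): Unit = {
    GpuColumnVector.debug("suspect-table", table)
    GpuColumnVector.debug("col-0", table.getColumn(0))
  }
}
```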
@@ -164,7 +164,8 @@ public final ColumnarMap getMap(int ordinal) {
public final Decimal getDecimal(int rowId, int precision, int scale) {
assert precision <= DType.DECIMAL64_MAX_PRECISION : "Assert " + precision + " <= DECIMAL64_MAX_PRECISION(" + DType.DECIMAL64_MAX_PRECISION + ")";
assert cudfCv.getType().getTypeId() == DType.DTypeEnum.DECIMAL64: "Assert DType to be DECIMAL64";
assert scale == -cudfCv.getType().getScale() : "Assert fetch decimal with its original scale";
assert scale == -cudfCv.getType().getScale() :
"Assert fetch decimal with its original scale " + scale + " expected " + (-cudfCv.getType().getScale());
return Decimal.createUnsafe(cudfCv.getLong(rowId), precision, scale);
}

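The expanded assertion message helps because cuDF and Spark use opposite sign conventions for decimal scale: a cuDF DECIMAL64 value is `unscaled * 10^cudfScale`, so Spark's `DecimalType(p, s)` corresponds to a cuDF scale of `-s`. A small worked sketch of that convention:

```scala
object ScaleConventionSketch extends App {
  val sparkScale = 2            // e.g. DECIMAL(12, 2)
  val cudfScale  = -sparkScale  // what cudfCv.getType().getScale() reports
  val unscaled   = 12345L
  // 12345 * 10^-2 == 123.45, i.e. a Java/Scala BigDecimal with scale 2.
  val value = BigDecimal(BigInt(unscaled), -cudfScale)
  println(value) // 123.45
}
```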
@@ -1477,7 +1477,9 @@ object GpuOverrides {
}
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowDecimal = conf.decimalTypeEnabled,
allowNull = true)

override def convertToGpu(): GpuExpression = {
// handle the case AggregateExpression has the resultIds parameter where its
// Seq[ExprIds] instead of single ExprId.
@@ -1501,7 +1503,9 @@ object GpuOverrides {
a.withNewChildren(childExprs.map(_.convertToGpu()))

override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowNull = true)
GpuOverrides.isSupportedType(t,
allowDecimal = conf.decimalTypeEnabled,
allowNull = true)
}),
expr[Count](
"Count aggregate operator",
@@ -1871,6 +1875,25 @@ object GpuOverrides {
"String character length",
(a, conf, p, r) => new UnaryExprMeta[Length](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression = GpuLength(child)
}),
expr[UnscaledValue](
"Convert a Decimal to an unscaled long value for some aggregation optimizations",
(a, conf, p, r) => new UnaryExprMeta[UnscaledValue](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression = GpuUnscaledValue(child)

override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowDecimal = conf.decimalTypeEnabled)
}),
expr[MakeDecimal](
"Create a Decimal from an unscaled long value form some aggregation optimizations",
(a, conf, p, r) => new UnaryExprMeta[MakeDecimal](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression =
GpuMakeDecimal(child, a.precision, a.scale, a.nullOnOverflow)

override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowDecimal = conf.decimalTypeEnabled)
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap

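For context, `UnscaledValue` and `MakeDecimal` normally appear because Spark's optimizer rewrites sums over sufficiently small decimals to do the arithmetic on unscaled longs and then rebuild a decimal at the end; supporting both on the GPU keeps that whole pipeline on the device. A round-trip sketch of what the two expressions compute for a single value:

```scala
import java.math.{BigDecimal => JBigDecimal}

object UnscaledRoundTripSketch extends App {
  val scale = 3
  val original = new JBigDecimal("1234.567")                      // a DECIMAL(7,3) value
  val unscaled: Long = original.unscaledValue().longValueExact()  // UnscaledValue: 1234567
  val remade = JBigDecimal.valueOf(unscaled, scale)               // MakeDecimal: back to 1234.567
  assert(remade.compareTo(original) == 0)
  println(s"unscaled=$unscaled remade=$remade")
}
```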
@@ -43,6 +43,7 @@ class GpuSortMeta(

override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowDecimal = conf.decimalTypeEnabled,
allowNull = true)

override def tagPlanForGpu(): Unit = {
@@ -101,6 +101,7 @@ class GpuHashAggregateMeta(
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowNull = true,
allowDecimal = conf.decimalTypeEnabled,
allowStringMaps = true)

override def tagPlanForGpu(): Unit = {
@@ -270,6 +271,7 @@ class GpuSortAggregateMeta(

override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowDecimal = conf.decimalTypeEnabled,
allowNull = true)

override def convertToGpu(): GpuExec = {
@@ -0,0 +1,68 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids

import ai.rapids.cudf.{ColumnVector, DType, Scalar}

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.{DataType, DecimalType, LongType}

case class GpuUnscaledValue(child: Expression) extends GpuUnaryExpression {
override def dataType: DataType = LongType
override def toString: String = s"UnscaledValue($child)"

override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
withResource(input.getBase.logicalCastTo(DType.INT64)) { view =>
view.copyToColumnVector()
}
}
}

case class GpuMakeDecimal(
child: Expression,
precision: Int,
sparkScale: Int,
nullOnOverflow: Boolean) extends GpuUnaryExpression {

override def dataType: DataType = DecimalType(precision, sparkScale)
override def nullable: Boolean = child.nullable || nullOnOverflow
override def toString: String = s"MakeDecimal($child,$precision,$sparkScale)"

private lazy val cudfScale = -sparkScale
private lazy val maxValue = BigDecimal(("9"*precision) + "e" + cudfScale.toString)
.bigDecimal.unscaledValue().longValue()

override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
val base = input.getBase
val outputType = DType.create(DType.DTypeEnum.DECIMAL64, cudfScale)
if (nullOnOverflow) {
val overflowed = withResource(Scalar.fromLong(maxValue)) { limit =>
base.greaterThan(limit)
}
withResource(overflowed) { overflowed =>
withResource(Scalar.fromNull(outputType)) { nullVal =>
withResource(base.logicalCastTo(outputType)) { view =>
overflowed.ifElse(nullVal, view)
}
}
}
} else {
withResource(base.logicalCastTo(outputType)) { view =>
view.copyToColumnVector()
}
}
}
}
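The `maxValue` bound in `GpuMakeDecimal` is the largest unscaled long that still fits the requested precision; anything above it is replaced with null when `nullOnOverflow` is set. A worked sketch of the bound for precision 7, scale 3, using the same construction as the code above:

```scala
object MakeDecimalBoundSketch extends App {
  val precision  = 7
  val sparkScale = 3
  val cudfScale  = -sparkScale
  // "9999999e-3" is 9999.999 with unscaled value 9999999, the largest DECIMAL(7,3).
  val maxValue = BigDecimal(("9" * precision) + "e" + cudfScale.toString)
    .bigDecimal.unscaledValue().longValue()
  println(maxValue) // 9999999: unscaled inputs above this overflow DECIMAL(7,3)
}
```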
@@ -29,13 +29,13 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, SinglePartition}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.metric._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.{GpuShuffleDependency, GpuShuffleEnv}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, DecimalType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.MutablePair

@@ -55,10 +55,22 @@ class GpuShuffleMeta(
// when AQE is enabled and we are planning a new query stage, we need to look at meta-data
// previously stored on the spark plan to determine whether this exchange can run on GPU
wrapped.getTagValue(gpuSupportedTag).foreach(_.foreach(willNotWorkOnGpu))

val hasDec = shuffle.schema.fields.map(_.dataType).exists(_.isInstanceOf[DecimalType])
if (hasDec) {
shuffle.outputPartitioning match {
case SinglePartition => // OK
// case _: HashPartitioning => //Hash Partitioning on decimal corrupts data
// https://github.com/rapidsai/cudf/issues/6996
case _: RangePartitioning => // OK
case o => willNotWorkOnGpu(s"Decimal for $o is not supported right now")
}
}
}

override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowDecimal = conf.decimalTypeEnabled,
allowNull = true)

override def convertToGpu(): GpuExec =
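The partitioning check above means a shuffle whose schema contains a decimal column only stays on the GPU for single-partition or range-partitioned exchanges; hash partitioning is deliberately left out until rapidsai/cudf#6996 is resolved. In user terms the difference looks roughly like this (a sketch, not a guarantee of plan shape):

```scala
import org.apache.spark.sql.SparkSession

object DecimalShufflePartitioningSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("decimal-shuffle-sketch").getOrCreate()
    val df = spark.range(1000).selectExpr("CAST(id AS DECIMAL(12,2)) AS d")

    // Repartitioning by a column produces a HashPartitioning exchange, which the
    // check above tags as unable to run on the GPU for decimal schemas.
    df.repartition(df("d")).count()

    // A global sort produces a RangePartitioning exchange, which is allowed.
    df.orderBy("d").count()
  }
}
```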
