diff --git a/docs/configs.md b/docs/configs.md
index 85ababd5425..773d2413640 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -176,6 +176,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
spark.rapids.sql.expression.Log2|`log2`|Log base 2|true|None|
spark.rapids.sql.expression.Logarithm|`log`|Log variable base|true|None|
spark.rapids.sql.expression.Lower|`lower`, `lcase`|String lowercase operator|false|This is not 100% compatible with the Spark version because in some cases unicode characters change byte width when changing the case. The GPU string conversion does not support these characters. For a full list of unsupported characters see https://github.com/rapidsai/cudf/issues/3132|
+spark.rapids.sql.expression.MakeDecimal| |Create a Decimal from an unscaled long value for some aggregation optimizations|true|None|
spark.rapids.sql.expression.Md5|`md5`|MD5 hash operator|true|None|
spark.rapids.sql.expression.Minute|`minute`|Returns the minute component of the string/timestamp|true|None|
spark.rapids.sql.expression.MonotonicallyIncreasingID|`monotonically_increasing_id`|Returns monotonically increasing 64-bit integers|true|None|
@@ -228,6 +229,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
spark.rapids.sql.expression.UnboundedFollowing$| |Special boundary for a window frame, indicating all rows following the current row|true|None|
spark.rapids.sql.expression.UnboundedPreceding$| |Special boundary for a window frame, indicating all rows preceding the current row|true|None|
spark.rapids.sql.expression.UnixTimestamp|`unix_timestamp`|Returns the UNIX timestamp of current or specified time|true|None|
+spark.rapids.sql.expression.UnscaledValue| |Convert a Decimal to an unscaled long value for some aggregation optimizations|true|None|
spark.rapids.sql.expression.Upper|`upper`, `ucase`|String uppercase operator|false|This is not 100% compatible with the Spark version because in some cases unicode characters change byte width when changing the case. The GPU string conversion does not support these characters. For a full list of unsupported characters see https://github.com/rapidsai/cudf/issues/3132|
spark.rapids.sql.expression.WeekDay|`weekday`|Returns the day of the week (0 = Monday...6=Sunday)|true|None|
spark.rapids.sql.expression.WindowExpression| |Calculates a return value for every input row of a table based on a group (or "window") of rows|true|None|
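
As a minimal sketch of how these table entries are consumed, assuming a standard plugin setup: the two expression keys are the ones documented above, while `spark.rapids.sql.decimalType.enabled` is an assumed key backing the `conf.decimalTypeEnabled` checks seen later in this diff.

```scala
import org.apache.spark.sql.SparkSession

// Sketch of a session wired for the new decimal expressions; config keys other than
// the two documented in the table above are assumptions based on this diff.
val spark = SparkSession.builder()
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .config("spark.rapids.sql.decimalType.enabled", "true")      // assumed key
  .config("spark.rapids.sql.expression.UnscaledValue", "true") // default per the table
  .config("spark.rapids.sql.expression.MakeDecimal", "true")   // default per the table
  .getOrCreate()
```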
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index 6074f8cf827..f3b27bb4b9f 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -28,6 +28,8 @@ class DataGen:
"""Base class for data generation"""
def __repr__(self):
+ if not self.nullable:
+ return self.__class__.__name__[:-3] + '(not_null)'
return self.__class__.__name__[:-3]
def __hash__(self):
@@ -722,6 +724,7 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
decimal_gen_neg_scale = DecimalGen(precision=7, scale=-3)
decimal_gen_scale_precision = DecimalGen(precision=7, scale=3)
decimal_gen_same_scale_precision = DecimalGen(precision=7, scale=7)
+decimal_gen_64bit = DecimalGen(precision=12, scale=2)
null_gen = NullGen()
@@ -734,7 +737,7 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
double_n_long_gens = [double_gen, long_gen]
int_n_long_gens = [int_gen, long_gen]
decimal_gens = [decimal_gen_default, decimal_gen_neg_scale, decimal_gen_scale_precision,
- decimal_gen_same_scale_precision]
+ decimal_gen_same_scale_precision, decimal_gen_64bit]
# all of the basic gens
all_basic_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
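
The new `decimal_gen_64bit` deliberately pushes precision past what a 32-bit unscaled value can hold while still fitting the DECIMAL64 path; a quick sketch of the arithmetic:

```scala
// Sketch: why precision 12 exercises the DECIMAL64 path. A DecimalType(12, 2) value's
// unscaled form can exceed Int.MaxValue (~2.1e9) but always fits in a signed long.
val maxUnscaled = BigInt(10).pow(12) - 1 // 999999999999
assert(maxUnscaled.isValidLong)          // Long.MaxValue is ~9.2e18
```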
diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py
index cf72e338cc4..0753ed4aee3 100644
--- a/integration_tests/src/main/python/hash_aggregate_test.py
+++ b/integration_tests/src/main/python/hash_aggregate_test.py
@@ -162,7 +162,7 @@ def get_params(init_list, marked_params=[]):
# Pytest marker for list of operators allowed to run on the CPU,
# esp. useful in partial and final only modes.
_excluded_operators_marker = pytest.mark.allow_non_gpu(
- 'HashAggregateExec', 'AggregateExpression',
+ 'HashAggregateExec', 'AggregateExpression', 'UnscaledValue', 'MakeDecimal',
'AttributeReference', 'Alias', 'Sum', 'Count', 'Max', 'Min', 'Average', 'Cast',
'KnownFloatingPointNormalized', 'NormalizeNaNAndZero', 'GreaterThan', 'Literal', 'If',
'EqualTo', 'First', 'SortAggregateExec', 'Coalesce')
@@ -173,10 +173,18 @@ def get_params(init_list, marked_params=[]):
]
+_grpkey_small_decimals = [
+ ('a', RepeatSeqGen(DecimalGen(precision=7, scale=3, nullable=(True, 10.0)), length=50)),
+ ('b', DecimalGen(precision=5, scale=2)),
+ ('c', DecimalGen(precision=8, scale=3))]
+
+_init_list_no_nans_with_decimal = _init_list_no_nans + [
+ _grpkey_small_decimals]
+
@approximate_float
@ignore_order
@incompat
-@pytest.mark.parametrize('data_gen', _init_list_no_nans, ids=idfn)
+@pytest.mark.parametrize('data_gen', _init_list_no_nans_with_decimal, ids=idfn)
@pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
def test_hash_grpby_sum(data_gen, conf):
assert_gpu_and_cpu_are_equal_collect(
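
Decimal group-by sums are what pull `UnscaledValue` and `MakeDecimal` into these plans: Spark's `DecimalAggregates` optimizer rule rewrites a sum over a small-precision decimal into long arithmetic on the unscaled values. A sketch of how to observe the rewrite, assuming a `spark` session:

```scala
import org.apache.spark.sql.functions.sum

// Sketch: Spark rewrites sums over small decimals into long math on unscaled values,
// so the GPU must handle both expressions for the aggregate to stay on the GPU.
// The result precision follows Spark's precision + 10 rule.
val df = spark.range(100).selectExpr("CAST(id AS DECIMAL(7,3)) AS d")
df.agg(sum("d")).explain()
// expect something like MakeDecimal(sum(UnscaledValue(d)), 17, 3) in the plan
```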
diff --git a/integration_tests/src/main/python/repart_test.py b/integration_tests/src/main/python/repart_test.py
index 228c7fb3cf4..43977ae7a0b 100644
--- a/integration_tests/src/main/python/repart_test.py
+++ b/integration_tests/src/main/python/repart_test.py
@@ -41,6 +41,7 @@ def test_coalesce_df(num_parts, length):
@ignore_order(local=True) # To avoid extra data shuffle by 'sort on Spark' for this repartition test.
def test_repartion_df(num_parts, length):
#This should change eventually to be more than just the basic gens
- gen_list = [('_c' + str(i), gen) for i, gen in enumerate(all_basic_gens)]
+ gen_list = [('_c' + str(i), gen) for i, gen in enumerate(all_basic_gens + decimal_gens)]
assert_gpu_and_cpu_are_equal_collect(
- lambda spark : gen_df(spark, gen_list, length=length).repartition(num_parts))
+ lambda spark : gen_df(spark, gen_list, length=length).repartition(num_parts),
+ conf = allow_negative_scale_of_decimal_conf)
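
The added `allow_negative_scale_of_decimal_conf` is needed because `decimal_gens` now includes `decimal_gen_neg_scale`, and Spark 3.x rejects negative decimal scales by default. A sketch of the behavior being re-enabled, assuming the helper maps to Spark's legacy flag:

```scala
// Sketch: negative-scale decimals; allow_negative_scale_of_decimal_conf is assumed
// to set this legacy Spark flag.
spark.conf.set("spark.sql.legacy.allowNegativeScaleOfDecimal", "true")
spark.sql("SELECT CAST(123456 AS DECIMAL(7,-3)) AS d").show()
// d = 123000: stored as unscaled value 123 with scale -3
```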
diff --git a/integration_tests/src/main/python/sort_test.py b/integration_tests/src/main/python/sort_test.py
index a90d61acb42..a98f82a1da8 100644
--- a/integration_tests/src/main/python/sort_test.py
+++ b/integration_tests/src/main/python/sort_test.py
@@ -20,36 +20,32 @@
from pyspark.sql.types import *
import pyspark.sql.functions as f
-orderable_gen_classes = [ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen,
- BooleanGen, TimestampGen, DateGen, StringGen, NullGen]
+orderable_not_null_gen = [ByteGen(nullable=False), ShortGen(nullable=False), IntegerGen(nullable=False),
+ LongGen(nullable=False), FloatGen(nullable=False), DoubleGen(nullable=False), BooleanGen(nullable=False),
+ TimestampGen(nullable=False), DateGen(nullable=False), StringGen(nullable=False), DecimalGen(nullable=False),
+ DecimalGen(precision=7, scale=-3, nullable=False), DecimalGen(precision=7, scale=3, nullable=False),
+ DecimalGen(precision=7, scale=7, nullable=False), DecimalGen(precision=12, scale=2, nullable=False)]
-@pytest.mark.parametrize('data_gen_class', orderable_gen_classes, ids=idfn)
-@pytest.mark.parametrize('nullable', [True, False], ids=idfn)
+@pytest.mark.parametrize('data_gen', orderable_gens + orderable_not_null_gen, ids=idfn)
@pytest.mark.parametrize('order', [f.col('a').asc(), f.col('a').asc_nulls_last(), f.col('a').desc(), f.col('a').desc_nulls_first()], ids=idfn)
-def test_single_orderby(data_gen_class, nullable, order):
- if (data_gen_class == NullGen):
- data_gen = data_gen_class()
- else:
- data_gen = data_gen_class(nullable=nullable)
+def test_single_orderby(data_gen, order):
assert_gpu_and_cpu_are_equal_collect(
- lambda spark : unary_op_df(spark, data_gen).orderBy(order))
+ lambda spark : unary_op_df(spark, data_gen).orderBy(order),
+ conf = allow_negative_scale_of_decimal_conf)
-@pytest.mark.parametrize('data_gen_class', orderable_gen_classes, ids=idfn)
-@pytest.mark.parametrize('nullable', [True, False], ids=idfn)
+@pytest.mark.parametrize('data_gen', orderable_gens + orderable_not_null_gen, ids=idfn)
@pytest.mark.parametrize('order', [f.col('a').asc(), f.col('a').asc_nulls_last(), f.col('a').desc(), f.col('a').desc_nulls_first()], ids=idfn)
-def test_single_sort_in_part(data_gen_class, nullable, order):
- if (data_gen_class == NullGen):
- data_gen = data_gen_class()
- else:
- data_gen = data_gen_class(nullable=nullable)
+def test_single_sort_in_part(data_gen, order):
assert_gpu_and_cpu_are_equal_collect(
- lambda spark : unary_op_df(spark, data_gen).sortWithinPartitions(order))
+ lambda spark : unary_op_df(spark, data_gen).sortWithinPartitions(order),
+ conf = allow_negative_scale_of_decimal_conf)
orderable_gens_sort = [byte_gen, short_gen, int_gen, long_gen,
pytest.param(float_gen, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/84')),
pytest.param(double_gen, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/84')),
- boolean_gen, timestamp_gen, date_gen, string_gen, null_gen]
+ boolean_gen, timestamp_gen, date_gen, string_gen, null_gen] + decimal_gens
@pytest.mark.parametrize('data_gen', orderable_gens_sort, ids=idfn)
def test_multi_orderby(data_gen):
assert_gpu_and_cpu_are_equal_collect(
- lambda spark : binary_op_df(spark, data_gen).orderBy(f.col('a'), f.col('b').desc()))
+ lambda spark : binary_op_df(spark, data_gen).orderBy(f.col('a'), f.col('b').desc()),
+ conf = allow_negative_scale_of_decimal_conf)
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
index 4dc57138f67..e4bea4328bc 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
@@ -39,6 +39,59 @@
*/
public class GpuColumnVector extends GpuColumnVectorBase {
+ /**
+ * Print to standard error the contents of a table. Note that this should never be
+ * called from production code, as it is very slow. Also note that this is not production
+ * code. You might need/want to update how the data shows up or add in support for more
+ * types as this really is just for debugging.
+ * @param name the name of the table to print out.
+ * @param table the table to print out.
+ */
+ public static synchronized void debug(String name, Table table) {
+ System.err.println("DEBUG " + name + " " + table);
+ for (int col = 0; col < table.getNumberOfColumns(); col++) {
+ debug(String.valueOf(col), table.getColumn(col));
+ }
+ }
+
+ /**
+ * Print to standard error the contents of a column. Note that this should never be
+ * called from production code, as it is very slow. Also note that this is not production
+ * code. You might need/want to update how the data shows up or add in support for more
+ * types as this really is just for debugging.
+ * @param name the name of the column to print out.
+ * @param col the column to print out.
+ */
+ public static synchronized void debug(String name, ai.rapids.cudf.ColumnVector col) {
+ try (HostColumnVector hostCol = col.copyToHost()) {
+ debug(name, hostCol);
+ }
+ }
+
+ /**
+ * Print to standard error the contents of a column. Note that this should never be
+ * called from production code, as it is very slow. Also note that this is not production
+ * code. You might need/want to update how the data shows up or add in support for more
+ * types as this really is just for debugging.
+ * @param name the name of the column to print out.
+ * @param hostCol the column to print out.
+ */
+ public static synchronized void debug(String name, HostColumnVector hostCol) {
+ DType type = hostCol.getType();
+ System.err.println("COLUMN " + name + " " + type);
+ if (type.getTypeId() == DType.DTypeEnum.DECIMAL64) {
+ for (int i = 0; i < hostCol.getRowCount(); i++) {
+ if (hostCol.isNull(i)) {
+ System.err.println(i + " NULL");
+ } else {
+ System.err.println(i + " " + hostCol.getBigDecimal(i));
+ }
+ }
+ } else {
+ System.err.println("TYPE " + type + " NOT SUPPORTED FOR DEBUG PRINT");
+ }
+ }
+
private static HostColumnVector.DataType convertFrom(DataType spark, boolean nullable) {
if (spark instanceof ArrayType) {
ArrayType arrayType = (ArrayType) spark;
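
For reference, a sketch of calling the new debug helper from a Scala test or console session; `decimalFromLongs` is assumed from the cudf Java API of this era:

```scala
import ai.rapids.cudf.ColumnVector
import com.nvidia.spark.rapids.GpuColumnVector

// Sketch of a debugging session, not production code. decimalFromLongs is assumed to
// build a DECIMAL64 column from unscaled longs; with cudf scale -2, 12345L is 123.45.
val col: ColumnVector = ColumnVector.decimalFromLongs(-2, 12345L, -99L)
try {
  GpuColumnVector.debug("my-decimals", col) // one BigDecimal (or NULL) per row
} finally {
  col.close()
}
```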
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVectorCore.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVectorCore.java
index efdbabe2088..4a64f315eae 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVectorCore.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVectorCore.java
@@ -164,7 +164,8 @@ public final ColumnarMap getMap(int ordinal) {
public final Decimal getDecimal(int rowId, int precision, int scale) {
assert precision <= DType.DECIMAL64_MAX_PRECISION : "Assert " + precision + " <= DECIMAL64_MAX_PRECISION(" + DType.DECIMAL64_MAX_PRECISION + ")";
assert cudfCv.getType().getTypeId() == DType.DTypeEnum.DECIMAL64: "Assert DType to be DECIMAL64";
- assert scale == -cudfCv.getType().getScale() : "Assert fetch decimal with its original scale";
+ assert scale == -cudfCv.getType().getScale() :
+ "Assert fetch decimal with its original scale " + scale + " expected " + (-cudfCv.getType().getScale());
return Decimal.createUnsafe(cudfCv.getLong(rowId), precision, scale);
}
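
The improved assert message surfaces the scale convention at play: cudf stores a DECIMAL64 scale as a power-of-ten exponent, the negation of Spark's scale. A worked example:

```scala
// Sketch of the convention the assert enforces: value = unscaled * 10^cudfScale,
// and sparkScale == -cudfScale. For a Spark DecimalType(7, 3) column:
val sparkScale = 3
val cudfScale = -sparkScale // what cudfCv.getType().getScale() reports
val value = BigDecimal(BigInt(1234567), sparkScale) // 1234.567
```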
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index faa6aea3b42..8d6757e5e1b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1477,7 +1477,9 @@ object GpuOverrides {
}
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
+ allowDecimal = conf.decimalTypeEnabled,
allowNull = true)
+
override def convertToGpu(): GpuExpression = {
// handle the case AggregateExpression has the resultIds parameter where its
// Seq[ExprIds] instead of single ExprId.
@@ -1501,7 +1503,9 @@ object GpuOverrides {
a.withNewChildren(childExprs.map(_.convertToGpu()))
override def isSupportedType(t: DataType): Boolean =
- GpuOverrides.isSupportedType(t, allowNull = true)
+ GpuOverrides.isSupportedType(t,
+ allowDecimal = conf.decimalTypeEnabled,
+ allowNull = true)
}),
expr[Count](
"Count aggregate operator",
@@ -1871,6 +1875,25 @@ object GpuOverrides {
"String character length",
(a, conf, p, r) => new UnaryExprMeta[Length](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression = GpuLength(child)
+ }),
+ expr[UnscaledValue](
+ "Convert a Decimal to an unscaled long value for some aggregation optimizations",
+ (a, conf, p, r) => new UnaryExprMeta[UnscaledValue](a, conf, p, r) {
+ override def convertToGpu(child: Expression): GpuExpression = GpuUnscaledValue(child)
+
+ override def isSupportedType(t: DataType): Boolean =
+ GpuOverrides.isSupportedType(t,
+ allowDecimal = conf.decimalTypeEnabled)
+ }),
+ expr[MakeDecimal](
+ "Create a Decimal from an unscaled long value form some aggregation optimizations",
+ (a, conf, p, r) => new UnaryExprMeta[MakeDecimal](a, conf, p, r) {
+ override def convertToGpu(child: Expression): GpuExpression =
+ GpuMakeDecimal(child, a.precision, a.scale, a.nullOnOverflow)
+
+ override def isSupportedType(t: DataType): Boolean =
+ GpuOverrides.isSupportedType(t,
+ allowDecimal = conf.decimalTypeEnabled)
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
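
Both metas wrap existing catalyst expressions, forwarding `a.precision`, `a.scale`, and `a.nullOnOverflow` straight through. A sketch of the CPU semantics `GpuMakeDecimal` mirrors, using Spark's own `Decimal` factory:

```scala
import org.apache.spark.sql.types.Decimal

// Sketch: reinterpret an unscaled long as a decimal with the requested precision/scale.
val d = Decimal(1234567L, 7, 3) // 1234.567
```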
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala
index 2af10fab32a..7fbd187dedf 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala
@@ -43,6 +43,7 @@ class GpuSortMeta(
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
+ allowDecimal = conf.decimalTypeEnabled,
allowNull = true)
override def tagPlanForGpu(): Unit = {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
index eca2cace297..bf8bc63b2b5 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
@@ -101,6 +101,7 @@ class GpuHashAggregateMeta(
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowNull = true,
+ allowDecimal = conf.decimalTypeEnabled,
allowStringMaps = true)
override def tagPlanForGpu(): Unit = {
@@ -270,6 +271,7 @@ class GpuSortAggregateMeta(
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
+ allowDecimal = conf.decimalTypeEnabled,
allowNull = true)
override def convertToGpu(): GpuExec = {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala
new file mode 100644
index 00000000000..d9623fd9f6c
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids
+
+import ai.rapids.cudf.{ColumnVector, DType, Scalar}
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.types.{DataType, DecimalType, LongType}
+
+case class GpuUnscaledValue(child: Expression) extends GpuUnaryExpression {
+ override def dataType: DataType = LongType
+ override def toString: String = s"UnscaledValue($child)"
+
+ override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
+ withResource(input.getBase.logicalCastTo(DType.INT64)) { view =>
+ view.copyToColumnVector()
+ }
+ }
+}
+
+case class GpuMakeDecimal(
+ child: Expression,
+ precision: Int,
+ sparkScale: Int,
+ nullOnOverflow: Boolean) extends GpuUnaryExpression {
+
+ override def dataType: DataType = DecimalType(precision, sparkScale)
+ override def nullable: Boolean = child.nullable || nullOnOverflow
+ override def toString: String = s"MakeDecimal($child,$precision,$sparkScale)"
+
+ private lazy val cudfScale = -sparkScale
+ private lazy val maxValue = BigDecimal(("9"*precision) + "e" + cudfScale.toString)
+ .bigDecimal.unscaledValue().longValue()
+
+ override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
+ val base = input.getBase
+ val outputType = DType.create(DType.DTypeEnum.DECIMAL64, cudfScale)
+ if (nullOnOverflow) {
+ val overflowed = withResource(Scalar.fromLong(maxValue)) { limit =>
+ base.greaterThan(limit)
+ }
+ withResource(overflowed) { overflowed =>
+ withResource(Scalar.fromNull(outputType)) { nullVal =>
+ withResource(base.logicalCastTo(outputType)) { view =>
+ overflowed.ifElse(nullVal, view)
+ }
+ }
+ }
+ } else {
+ withResource(base.logicalCastTo(outputType)) { view =>
+ view.copyToColumnVector()
+ }
+ }
+ }
+}
\ No newline at end of file
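
A worked instance of the `maxValue` limit above, for `precision = 7` and `sparkScale = 3`. Note that as written the check only guards the positive side; unscaled values below `-maxValue` are not nulled.

```scala
// Worked example: "9" * 7 + "e-3" parses to 9999.999, whose unscaled value is
// 9999999, i.e. 10^precision - 1, independent of the scale.
val precision = 7
val cudfScale = -3
val maxValue = BigDecimal(("9" * precision) + "e" + cudfScale.toString)
  .bigDecimal.unscaledValue().longValue()
assert(maxValue == 9999999L)
```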
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
index a93f9461279..50d1ac4e82f 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
@@ -29,13 +29,13 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, SinglePartition}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.metric._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.{GpuShuffleDependency, GpuShuffleEnv}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DecimalType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.MutablePair
@@ -55,10 +55,22 @@ class GpuShuffleMeta(
// when AQE is enabled and we are planning a new query stage, we need to look at meta-data
// previously stored on the spark plan to determine whether this exchange can run on GPU
wrapped.getTagValue(gpuSupportedTag).foreach(_.foreach(willNotWorkOnGpu))
+
+ val hasDec = shuffle.schema.fields.map(_.dataType).exists(_.isInstanceOf[DecimalType])
+ if (hasDec) {
+ shuffle.outputPartitioning match {
+ case SinglePartition => // OK
+        // case _: HashPartitioning => // hash partitioning on decimal corrupts data,
+        //   see https://github.com/rapidsai/cudf/issues/6996
+ case _: RangePartitioning => // OK
+ case o => willNotWorkOnGpu(s"Decimal for $o is not supported right now")
+ }
+ }
}
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
+ allowDecimal = conf.decimalTypeEnabled,
allowNull = true)
override def convertToGpu(): GpuExec =
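
With the hash-partitioning case commented out above, a sketch of which decimal shuffles stay on the GPU (hypothetical queries, assuming a `spark` session with the plugin enabled):

```scala
// Sketch: decimal shuffle support until cudf#6996 is resolved.
val df = spark.range(1000).selectExpr("CAST(id AS DECIMAL(12, 2)) AS d")
df.orderBy("d").collect()          // RangePartitioning: OK on the GPU
df.repartition(df("d")).collect()  // HashPartitioning: falls back to the CPU for now
// SinglePartition shuffles (e.g., a global aggregate's final stage) are also OK.
```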