From d3f3096dfe7130bf2170b5d30f689d3bb961174e Mon Sep 17 00:00:00 2001
From: "Robert (Bobby) Evans"
Date: Mon, 14 Dec 2020 11:18:54 -0600
Subject: [PATCH] Adds in basic support for decimal sort, sum, and some shuffle (#1380)

Signed-off-by: Robert (Bobby) Evans
---
 docs/configs.md                               |  2 +
 integration_tests/src/main/python/data_gen.py |  5 +-
 .../src/main/python/hash_aggregate_test.py    | 12 +++-
 .../src/main/python/repart_test.py            |  5 +-
 .../src/main/python/sort_test.py              | 36 +++++-----
 .../nvidia/spark/rapids/GpuColumnVector.java  | 53 +++++++++++++++
 .../rapids/RapidsHostColumnVectorCore.java    |  3 +-
 .../nvidia/spark/rapids/GpuOverrides.scala    | 25 ++++++-
 .../com/nvidia/spark/rapids/GpuSortExec.scala |  1 +
 .../com/nvidia/spark/rapids/aggregate.scala   |  2 +
 .../spark/rapids/decimalExpressions.scala     | 68 +++++++++++++++++++
 .../execution/GpuShuffleExchangeExec.scala    | 16 ++++-
 12 files changed, 199 insertions(+), 29 deletions(-)
 create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala

diff --git a/docs/configs.md b/docs/configs.md
index 85ababd5425..773d2413640 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -176,6 +176,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
 spark.rapids.sql.expression.Log2|`log2`|Log base 2|true|None|
 spark.rapids.sql.expression.Logarithm|`log`|Log variable base|true|None|
 spark.rapids.sql.expression.Lower|`lower`, `lcase`|String lowercase operator|false|This is not 100% compatible with the Spark version because in some cases unicode characters change byte width when changing the case. The GPU string conversion does not support these characters. For a full list of unsupported characters see https://github.com/rapidsai/cudf/issues/3132|
+spark.rapids.sql.expression.MakeDecimal| |Create a Decimal from an unscaled long value for some aggregation optimizations|true|None|
 spark.rapids.sql.expression.Md5|`md5`|MD5 hash operator|true|None|
 spark.rapids.sql.expression.Minute|`minute`|Returns the minute component of the string/timestamp|true|None|
 spark.rapids.sql.expression.MonotonicallyIncreasingID|`monotonically_increasing_id`|Returns monotonically increasing 64-bit integers|true|None|
@@ -228,6 +229,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
 spark.rapids.sql.expression.UnboundedFollowing$| |Special boundary for a window frame, indicating all rows preceding the current row|true|None|
 spark.rapids.sql.expression.UnboundedPreceding$| |Special boundary for a window frame, indicating all rows preceding the current row|true|None|
 spark.rapids.sql.expression.UnixTimestamp|`unix_timestamp`|Returns the UNIX timestamp of current or specified time|true|None|
+spark.rapids.sql.expression.UnscaledValue| |Convert a Decimal to an unscaled long value for some aggregation optimizations|true|None|
 spark.rapids.sql.expression.Upper|`upper`, `ucase`|String uppercase operator|false|This is not 100% compatible with the Spark version because in some cases unicode characters change byte width when changing the case. The GPU string conversion does not support these characters. For a full list of unsupported characters see https://github.com/rapidsai/cudf/issues/3132|
 spark.rapids.sql.expression.WeekDay|`weekday`|Returns the day of the week (0 = Monday...6=Sunday)|true|None|
 spark.rapids.sql.expression.WindowExpression| |Calculates a return value for every input row of a table based on a group (or "window") of rows|true|None|
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index 6074f8cf827..f3b27bb4b9f 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -28,6 +28,8 @@ class DataGen:
     """Base class for data generation"""
 
     def __repr__(self):
+        if not self.nullable:
+            return self.__class__.__name__[:-3] + '(not_null)'
         return self.__class__.__name__[:-3]
 
     def __hash__(self):
@@ -722,6 +724,7 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
 decimal_gen_neg_scale = DecimalGen(precision=7, scale=-3)
 decimal_gen_scale_precision = DecimalGen(precision=7, scale=3)
 decimal_gen_same_scale_precision = DecimalGen(precision=7, scale=7)
+decimal_gen_64bit = DecimalGen(precision=12, scale=2)
 
 null_gen = NullGen()
 
@@ -734,7 +737,7 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
 double_n_long_gens = [double_gen, long_gen]
 int_n_long_gens = [int_gen, long_gen]
 decimal_gens = [decimal_gen_default, decimal_gen_neg_scale, decimal_gen_scale_precision,
-        decimal_gen_same_scale_precision]
+        decimal_gen_same_scale_precision, decimal_gen_64bit]
 
 # all of the basic gens
 all_basic_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py
index cf72e338cc4..0753ed4aee3 100644
--- a/integration_tests/src/main/python/hash_aggregate_test.py
+++ b/integration_tests/src/main/python/hash_aggregate_test.py
@@ -162,7 +162,7 @@ def get_params(init_list, marked_params=[]):
 # Pytest marker for list of operators allowed to run on the CPU,
 # esp. useful in partial and final only modes.
 _excluded_operators_marker = pytest.mark.allow_non_gpu(
-    'HashAggregateExec', 'AggregateExpression',
+    'HashAggregateExec', 'AggregateExpression', 'UnscaledValue', 'MakeDecimal',
     'AttributeReference', 'Alias', 'Sum', 'Count', 'Max', 'Min', 'Average', 'Cast',
     'KnownFloatingPointNormalized', 'NormalizeNaNAndZero', 'GreaterThan', 'Literal', 'If',
     'EqualTo', 'First', 'SortAggregateExec', 'Coalesce')
@@ -173,10 +173,18 @@ def get_params(init_list, marked_params=[]):
     ]
 
+_grpkey_small_decimals = [
+    ('a', RepeatSeqGen(DecimalGen(precision=7, scale=3, nullable=(True, 10.0)), length=50)),
+    ('b', DecimalGen(precision=5, scale=2)),
+    ('c', DecimalGen(precision=8, scale=3))]
+
+_init_list_no_nans_with_decimal = _init_list_no_nans + [
+    _grpkey_small_decimals]
+
 @approximate_float
 @ignore_order
 @incompat
-@pytest.mark.parametrize('data_gen', _init_list_no_nans, ids=idfn)
+@pytest.mark.parametrize('data_gen', _init_list_no_nans_with_decimal, ids=idfn)
 @pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
 def test_hash_grpby_sum(data_gen, conf):
     assert_gpu_and_cpu_are_equal_collect(
diff --git a/integration_tests/src/main/python/repart_test.py b/integration_tests/src/main/python/repart_test.py
index 228c7fb3cf4..43977ae7a0b 100644
--- a/integration_tests/src/main/python/repart_test.py
+++ b/integration_tests/src/main/python/repart_test.py
@@ -41,6 +41,7 @@ def test_coalesce_df(num_parts, length):
 @ignore_order(local=True) # To avoid extra data shuffle by 'sort on Spark' for this repartition test.
 def test_repartion_df(num_parts, length):
     #This should change eventually to be more than just the basic gens
-    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(all_basic_gens)]
+    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(all_basic_gens + decimal_gens)]
     assert_gpu_and_cpu_are_equal_collect(
-            lambda spark : gen_df(spark, gen_list, length=length).repartition(num_parts))
+            lambda spark : gen_df(spark, gen_list, length=length).repartition(num_parts),
+            conf = allow_negative_scale_of_decimal_conf)
diff --git a/integration_tests/src/main/python/sort_test.py b/integration_tests/src/main/python/sort_test.py
index a90d61acb42..a98f82a1da8 100644
--- a/integration_tests/src/main/python/sort_test.py
+++ b/integration_tests/src/main/python/sort_test.py
@@ -20,36 +20,32 @@
 from pyspark.sql.types import *
 import pyspark.sql.functions as f
 
-orderable_gen_classes = [ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen,
-        BooleanGen, TimestampGen, DateGen, StringGen, NullGen]
+orderable_not_null_gen = [ByteGen(nullable=False), ShortGen(nullable=False), IntegerGen(nullable=False),
+        LongGen(nullable=False), FloatGen(nullable=False), DoubleGen(nullable=False), BooleanGen(nullable=False),
+        TimestampGen(nullable=False), DateGen(nullable=False), StringGen(nullable=False), DecimalGen(nullable=False),
+        DecimalGen(precision=7, scale=-3, nullable=False), DecimalGen(precision=7, scale=3, nullable=False),
+        DecimalGen(precision=7, scale=7, nullable=False), DecimalGen(precision=12, scale=2, nullable=False)]
 
-@pytest.mark.parametrize('data_gen_class', orderable_gen_classes, ids=idfn)
-@pytest.mark.parametrize('nullable', [True, False], ids=idfn)
+@pytest.mark.parametrize('data_gen', orderable_gens + orderable_not_null_gen, ids=idfn)
 @pytest.mark.parametrize('order', [f.col('a').asc(), f.col('a').asc_nulls_last(), f.col('a').desc(), f.col('a').desc_nulls_first()], ids=idfn)
-def test_single_orderby(data_gen_class, nullable, order):
-    if (data_gen_class == NullGen):
-        data_gen = data_gen_class()
-    else:
-        data_gen = data_gen_class(nullable=nullable)
+def test_single_orderby(data_gen, order):
     assert_gpu_and_cpu_are_equal_collect(
-            lambda spark : unary_op_df(spark, data_gen).orderBy(order))
+            lambda spark : unary_op_df(spark, data_gen).orderBy(order),
+            conf = allow_negative_scale_of_decimal_conf)
 
-@pytest.mark.parametrize('data_gen_class', orderable_gen_classes, ids=idfn)
-@pytest.mark.parametrize('nullable', [True, False], ids=idfn)
+@pytest.mark.parametrize('data_gen', orderable_gens + orderable_not_null_gen, ids=idfn)
 @pytest.mark.parametrize('order', [f.col('a').asc(), f.col('a').asc_nulls_last(), f.col('a').desc(), f.col('a').desc_nulls_first()], ids=idfn)
-def test_single_sort_in_part(data_gen_class, nullable, order):
-    if (data_gen_class == NullGen):
-        data_gen = data_gen_class()
-    else:
-        data_gen = data_gen_class(nullable=nullable)
+def test_single_sort_in_part(data_gen, order):
     assert_gpu_and_cpu_are_equal_collect(
-            lambda spark : unary_op_df(spark, data_gen).sortWithinPartitions(order))
+            lambda spark : unary_op_df(spark, data_gen).sortWithinPartitions(order),
+            conf = allow_negative_scale_of_decimal_conf)
 
 orderable_gens_sort = [byte_gen, short_gen, int_gen, long_gen,
         pytest.param(float_gen, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/84')),
         pytest.param(double_gen, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/84')),
-        boolean_gen, timestamp_gen, date_gen, string_gen, null_gen]
+        boolean_gen, timestamp_gen, date_gen, string_gen, null_gen] + decimal_gens
 @pytest.mark.parametrize('data_gen', orderable_gens_sort, ids=idfn)
 def test_multi_orderby(data_gen):
     assert_gpu_and_cpu_are_equal_collect(
-        lambda spark : binary_op_df(spark, data_gen).orderBy(f.col('a'), f.col('b').desc()))
+        lambda spark : binary_op_df(spark, data_gen).orderBy(f.col('a'), f.col('b').desc()),
+        conf = allow_negative_scale_of_decimal_conf)
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
index 4dc57138f67..e4bea4328bc 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
@@ -39,6 +39,59 @@
  */
 public class GpuColumnVector extends GpuColumnVectorBase {
 
+  /**
+   * Print to standard error the contents of a table. Note that this should never be
+   * called from production code, as it is very slow. Also note that this is not production
+   * code. You might need/want to update how the data shows up or add in support for more
+   * types as this really is just for debugging.
+   * @param name the name of the table to print out.
+   * @param table the table to print out.
+   */
+  public static synchronized void debug(String name, Table table) {
+    System.err.println("DEBUG " + name + " " + table);
+    for (int col = 0; col < table.getNumberOfColumns(); col++) {
+      debug(String.valueOf(col), table.getColumn(col));
+    }
+  }
+
+  /**
+   * Print to standard error the contents of a column. Note that this should never be
+   * called from production code, as it is very slow. Also note that this is not production
+   * code. You might need/want to update how the data shows up or add in support for more
+   * types as this really is just for debugging.
+   * @param name the name of the column to print out.
+   * @param col the column to print out.
+   */
+  public static synchronized void debug(String name, ai.rapids.cudf.ColumnVector col) {
+    try (HostColumnVector hostCol = col.copyToHost()) {
+      debug(name, hostCol);
+    }
+  }
+
+  /**
+   * Print to standard error the contents of a column. Note that this should never be
+   * called from production code, as it is very slow. Also note that this is not production
+   * code. You might need/want to update how the data shows up or add in support for more
+   * types as this really is just for debugging.
+   * @param name the name of the column to print out.
+   * @param hostCol the column to print out.
+   */
+  public static synchronized void debug(String name, HostColumnVector hostCol) {
+    DType type = hostCol.getType();
+    System.err.println("COLUMN " + name + " " + type);
+    if (type.getTypeId() == DType.DTypeEnum.DECIMAL64) {
+      for (int i = 0; i < hostCol.getRowCount(); i++) {
+        if (hostCol.isNull(i)) {
+          System.err.println(i + " NULL");
+        } else {
+          System.err.println(i + " " + hostCol.getBigDecimal(i));
+        }
+      }
+    } else {
+      System.err.println("TYPE " + type + " NOT SUPPORTED FOR DEBUG PRINT");
+    }
+  }
+
   private static HostColumnVector.DataType convertFrom(DataType spark, boolean nullable) {
     if (spark instanceof ArrayType) {
       ArrayType arrayType = (ArrayType) spark;
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVectorCore.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVectorCore.java
index efdbabe2088..4a64f315eae 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVectorCore.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVectorCore.java
@@ -164,7 +164,8 @@ public final ColumnarMap getMap(int ordinal) {
   public final Decimal getDecimal(int rowId, int precision, int scale) {
     assert precision <= DType.DECIMAL64_MAX_PRECISION : "Assert " + precision + " <= DECIMAL64_MAX_PRECISION(" + DType.DECIMAL64_MAX_PRECISION + ")";
     assert cudfCv.getType().getTypeId() == DType.DTypeEnum.DECIMAL64: "Assert DType to be DECIMAL64";
-    assert scale == -cudfCv.getType().getScale() : "Assert fetch decimal with its original scale";
+    assert scale == -cudfCv.getType().getScale() :
+        "Assert fetch decimal with its original scale " + scale + " expected " + (-cudfCv.getType().getScale());
     return Decimal.createUnsafe(cudfCv.getLong(rowId), precision, scale);
   }
 
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index faa6aea3b42..8d6757e5e1b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1477,7 +1477,9 @@ object GpuOverrides {
       }
       override def isSupportedType(t: DataType): Boolean =
         GpuOverrides.isSupportedType(t,
+          allowDecimal = conf.decimalTypeEnabled,
           allowNull = true)
+
       override def convertToGpu(): GpuExpression = {
         // handle the case AggregateExpression has the resultIds parameter where its
         // Seq[ExprIds] instead of single ExprId.
@@ -1501,7 +1503,9 @@ object GpuOverrides {
           a.withNewChildren(childExprs.map(_.convertToGpu()))
 
         override def isSupportedType(t: DataType): Boolean =
-          GpuOverrides.isSupportedType(t, allowNull = true)
+          GpuOverrides.isSupportedType(t,
+            allowDecimal = conf.decimalTypeEnabled,
+            allowNull = true)
       }),
     expr[Count](
       "Count aggregate operator",
@@ -1871,6 +1875,25 @@ object GpuOverrides {
       "String character length",
       (a, conf, p, r) => new UnaryExprMeta[Length](a, conf, p, r) {
         override def convertToGpu(child: Expression): GpuExpression = GpuLength(child)
+      }),
+    expr[UnscaledValue](
+      "Convert a Decimal to an unscaled long value for some aggregation optimizations",
+      (a, conf, p, r) => new UnaryExprMeta[UnscaledValue](a, conf, p, r) {
+        override def convertToGpu(child: Expression): GpuExpression = GpuUnscaledValue(child)
+
+        override def isSupportedType(t: DataType): Boolean =
+          GpuOverrides.isSupportedType(t,
+            allowDecimal = conf.decimalTypeEnabled)
+      }),
+    expr[MakeDecimal](
+      "Create a Decimal from an unscaled long value for some aggregation optimizations",
+      (a, conf, p, r) => new UnaryExprMeta[MakeDecimal](a, conf, p, r) {
+        override def convertToGpu(child: Expression): GpuExpression =
+          GpuMakeDecimal(child, a.precision, a.scale, a.nullOnOverflow)
+
+        override def isSupportedType(t: DataType): Boolean =
+          GpuOverrides.isSupportedType(t,
+            allowDecimal = conf.decimalTypeEnabled)
       })
   ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
 
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala
index 2af10fab32a..7fbd187dedf 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala
@@ -43,6 +43,7 @@ class GpuSortMeta(
 
   override def isSupportedType(t: DataType): Boolean =
     GpuOverrides.isSupportedType(t,
+      allowDecimal = conf.decimalTypeEnabled,
       allowNull = true)
 
   override def tagPlanForGpu(): Unit = {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
index eca2cace297..bf8bc63b2b5 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
@@ -101,6 +101,7 @@ class GpuHashAggregateMeta(
   override def isSupportedType(t: DataType): Boolean =
     GpuOverrides.isSupportedType(t,
       allowNull = true,
+      allowDecimal = conf.decimalTypeEnabled,
       allowStringMaps = true)
 
   override def tagPlanForGpu(): Unit = {
@@ -270,6 +271,7 @@ class GpuSortAggregateMeta(
 
   override def isSupportedType(t: DataType): Boolean =
     GpuOverrides.isSupportedType(t,
+      allowDecimal = conf.decimalTypeEnabled,
       allowNull = true)
 
   override def convertToGpu(): GpuExec = {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala
new file mode 100644
index 00000000000..d9623fd9f6c
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids
+
+import ai.rapids.cudf.{ColumnVector, DType, Scalar}
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.types.{DataType, DecimalType, LongType}
+
+case class GpuUnscaledValue(child: Expression) extends GpuUnaryExpression {
+  override def dataType: DataType = LongType
+  override def toString: String = s"UnscaledValue($child)"
+
+  override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
+    withResource(input.getBase.logicalCastTo(DType.INT64)) { view =>
+      view.copyToColumnVector()
+    }
+  }
+}
+
+case class GpuMakeDecimal(
+    child: Expression,
+    precision: Int,
+    sparkScale: Int,
+    nullOnOverflow: Boolean) extends GpuUnaryExpression {
+
+  override def dataType: DataType = DecimalType(precision, sparkScale)
+  override def nullable: Boolean = child.nullable || nullOnOverflow
+  override def toString: String = s"MakeDecimal($child,$precision,$sparkScale)"
+
+  private lazy val cudfScale = -sparkScale
+  private lazy val maxValue = BigDecimal(("9"*precision) + "e" + cudfScale.toString)
+      .bigDecimal.unscaledValue().longValue()
+
+  override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
+    val base = input.getBase
+    val outputType = DType.create(DType.DTypeEnum.DECIMAL64, cudfScale)
+    if (nullOnOverflow) {
+      val overflowed = withResource(Scalar.fromLong(maxValue)) { limit =>
+        base.greaterThan(limit)
+      }
+      withResource(overflowed) { overflowed =>
+        withResource(Scalar.fromNull(outputType)) { nullVal =>
+          withResource(base.logicalCastTo(outputType)) { view =>
+            overflowed.ifElse(nullVal, view)
+          }
+        }
+      }
+    } else {
+      withResource(base.logicalCastTo(outputType)) { view =>
+        view.copyToColumnVector()
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
index a93f9461279..50d1ac4e82f 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
@@ -29,13 +29,13 @@
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, SinglePartition}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.metric._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.rapids.{GpuShuffleDependency, GpuShuffleEnv}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DecimalType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.MutablePair
@@ -55,10 +55,22 @@ class GpuShuffleMeta(
     // when AQE is enabled and we are planning a new query stage, we need to look at meta-data
     // previously stored on the spark plan to determine whether this exchange can run on GPU
     wrapped.getTagValue(gpuSupportedTag).foreach(_.foreach(willNotWorkOnGpu))
+
+    val hasDec = shuffle.schema.fields.map(_.dataType).exists(_.isInstanceOf[DecimalType])
+    if (hasDec) {
+      shuffle.outputPartitioning match {
+        case SinglePartition => // OK
+// case _: HashPartitioning => //Hash Partitioning on decimal corrupts data
+          // https://github.com/rapidsai/cudf/issues/6996
+        case _: RangePartitioning => // OK
+        case o => willNotWorkOnGpu(s"Decimal for $o is not supported right now")
+      }
+    }
   }
 
   override def isSupportedType(t: DataType): Boolean =
     GpuOverrides.isSupportedType(t,
+      allowDecimal = conf.decimalTypeEnabled,
       allowNull = true)
 
   override def convertToGpu(): GpuExec =
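
The patch above only wires decimal support into sort, sum, and parts of shuffle. As a minimal sketch of the kind of query it is meant to accelerate, the PySpark snippet below runs a decimal group-by sum, an order-by, and a repartition-plus-sort on a DECIMAL(12,2) column. It is illustrative only and not part of the change: the plugin setup and the spark.rapids.sql.decimalType.enabled key (assumed here to be the config behind conf.decimalTypeEnabled in this patch) should be adjusted to match the actual deployment.

# Minimal sketch, assuming the RAPIDS Accelerator jars are on the classpath.
from decimal import Decimal
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .appName("decimal-sum-sort-smoke-test")
         .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
         .config("spark.rapids.sql.enabled", "true")
         .config("spark.rapids.sql.decimalType.enabled", "true")  # assumed config key
         .getOrCreate())

# DECIMAL(12,2) stays inside the DECIMAL64 range this patch handles.
df = spark.createDataFrame(
    [(i % 5, Decimal(i) / 100) for i in range(1000)],
    "k INT, v DECIMAL(12,2)")

# Group-by sum followed by a sort: the hash aggregate and GpuSortExec paths.
df.groupBy("k").sum("v").orderBy("k").show()

# Repartition plus sortWithinPartitions: the shuffle path for decimal columns.
df.repartition(4).sortWithinPartitions("v").show(5)

With the decimal flag left off, plans containing DecimalType are expected to fall back to the CPU, which gives an easy A/B comparison against the new GPU path.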