From b1cfb50a804ed75684b9f35c2eaa45091ebc3c55 Mon Sep 17 00:00:00 2001
From: Remzi Yang <59198230+HaoYang670@users.noreply.github.com>
Date: Wed, 6 Apr 2022 22:38:22 +0800
Subject: [PATCH] Add support for all types (#5132)
Signed-off-by: remzi <13716567376yh@gmail.com>
---
docs/supported_ops.md | 36 +++++++++----------
.../src/main/python/hash_aggregate_test.py | 2 +-
.../nvidia/spark/rapids/GpuOverrides.scala | 4 +--
3 files changed, 20 insertions(+), 22 deletions(-)
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index abb928b640c..eac820bb48a 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -15130,12 +15130,12 @@ are limited.
S |
S |
S |
-NS |
-NS |
-NS |
-NS |
-PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT |
-NS |
+S |
+S |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+S |
result |
@@ -15173,12 +15173,12 @@ are limited.
S |
S |
S |
-NS |
-NS |
-NS |
-NS |
-PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT |
-NS |
+S |
+S |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+S |
result |
@@ -15216,12 +15216,12 @@ are limited.
S |
S |
S |
-NS |
-NS |
-NS |
-NS |
-PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT |
-NS |
+S |
+S |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+S |
result |
diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py
index 4ab98da2eae..8f2b3af984e 100644
--- a/integration_tests/src/main/python/hash_aggregate_test.py
+++ b/integration_tests/src/main/python/hash_aggregate_test.py
@@ -1051,7 +1051,7 @@ def test_generic_reductions(data_gen):
'count(1)'),
conf = local_conf)
-@pytest.mark.parametrize('data_gen', non_nan_all_basic_gens, ids=idfn)
+@pytest.mark.parametrize('data_gen', all_gen + _nested_gens, ids=idfn)
def test_count(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen) \
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index daf8d56f33d..bc007643981 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -2245,9 +2245,7 @@ object GpuOverrides extends Logging {
ExprChecks.fullAgg(
TypeSig.LONG, TypeSig.LONG,
repeatingParamCheck = Some(RepeatingParamCheck(
- "input", _gpuCommonTypes + TypeSig.DECIMAL_128 +
- TypeSig.STRUCT.nested(_gpuCommonTypes + TypeSig.DECIMAL_128),
- TypeSig.all))),
+ "input", TypeSig.all, TypeSig.all))),
(count, conf, p, r) => new AggExprMeta[Count](count, conf, p, r) {
override def tagAggForGpu(): Unit = {
if (count.children.size > 1) {