diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 6d2e06f6295..8c26e9e97c1 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -187,13 +187,13 @@ Accelerator supports are described below.
 S
 S*
 S
-NS
+S*
 S
 NS
 NS
-PS* (missing nested DECIMAL, BINARY, CALENDAR, UDT)
-PS* (missing nested DECIMAL, BINARY, CALENDAR, UDT)
-PS* (missing nested DECIMAL, BINARY, CALENDAR, UDT)
+PS* (missing nested BINARY, CALENDAR, UDT)
+PS* (missing nested BINARY, CALENDAR, UDT)
+PS* (missing nested BINARY, CALENDAR, UDT)
 NS
@@ -486,13 +486,13 @@ Accelerator supports are described below.
 S
 S*
 S
+S*
 NS
 NS
 NS
-NS
-PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, UDT)
-PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, UDT)
-PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, UDT)
+PS* (missing nested NULL, BINARY, CALENDAR, UDT)
+PS* (missing nested NULL, BINARY, CALENDAR, UDT)
+PS* (missing nested NULL, BINARY, CALENDAR, UDT)
 NS
@@ -17043,13 +17043,13 @@ dates or timestamps, or for a lack of type coercion support.
 S
 S
 S
-NS
+S
 NS
-PS (missing nested DECIMAL, BINARY)
-PS (missing nested DECIMAL, BINARY)
-PS (missing nested DECIMAL, BINARY)
+PS (missing nested BINARY)
+PS (missing nested BINARY)
+PS (missing nested BINARY)
 Output
diff --git a/integration_tests/src/main/python/parquet_test.py b/integration_tests/src/main/python/parquet_test.py
index d78e4ef0588..ea784a853d9 100644
--- a/integration_tests/src/main/python/parquet_test.py
+++ b/integration_tests/src/main/python/parquet_test.py
@@ -27,23 +27,30 @@ def read_parquet_df(data_path):
 
 def read_parquet_sql(data_path):
     return lambda spark : spark.sql('select * from parquet.`{}`'.format(data_path))
 
+
+# Override decimal_gens because decimal with negative scale is unsupported in parquet reading
+decimal_gens = [DecimalGen(), DecimalGen(precision=7, scale=3), DecimalGen(precision=10, scale=10),
+    DecimalGen(precision=9, scale=0), DecimalGen(precision=18, scale=15)]
+
 parquet_gens_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
     string_gen, boolean_gen, date_gen,
     TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc)),
     ArrayGen(byte_gen), ArrayGen(long_gen), ArrayGen(string_gen), ArrayGen(date_gen),
     ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))),
+    ArrayGen(DecimalGen()),
     ArrayGen(ArrayGen(byte_gen)),
-    StructGen([['child0', ArrayGen(byte_gen)], ['child1', byte_gen], ['child2', float_gen]]),
-    ArrayGen(StructGen([['child0', string_gen], ['child1', double_gen], ['child2', int_gen]]))] + map_gens_sample,
-    pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/132'))]
+    StructGen([['child0', ArrayGen(byte_gen)], ['child1', byte_gen], ['child2', float_gen], ['child3', DecimalGen()]]),
+    ArrayGen(StructGen([['child0', string_gen], ['child1', double_gen], ['child2', int_gen]]))] +
+    map_gens_sample + decimal_gens,
+    pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/132'))]
 
 # test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for
 # non-cloud
-original_parquet_file_reader_conf={'spark.rapids.sql.format.parquet.reader.type': 'PERFILE'}
-multithreaded_parquet_file_reader_conf={'spark.rapids.sql.format.parquet.reader.type': 'MULTITHREADED'}
-coalesce_parquet_file_reader_conf={'spark.rapids.sql.format.parquet.reader.type': 'COALESCING'}
+original_parquet_file_reader_conf = {'spark.rapids.sql.format.parquet.reader.type': 'PERFILE'}
+multithreaded_parquet_file_reader_conf = {'spark.rapids.sql.format.parquet.reader.type': 'MULTITHREADED'}
+coalesce_parquet_file_reader_conf = {'spark.rapids.sql.format.parquet.reader.type': 'COALESCING'}
 reader_opt_confs = [original_parquet_file_reader_conf, multithreaded_parquet_file_reader_conf,
-        coalesce_parquet_file_reader_conf]
+    coalesce_parquet_file_reader_conf]
 
 @pytest.mark.parametrize('parquet_gens', parquet_gens_list, ids=idfn)
 @pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
@@ -66,9 +73,9 @@ def test_read_round_trip(spark_tmp_path, parquet_gens, read_func, reader_confs,
 @pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
 @pytest.mark.parametrize('disable_conf', ['spark.rapids.sql.format.parquet.enabled', 'spark.rapids.sql.format.parquet.read.enabled'])
 def test_parquet_fallback(spark_tmp_path, read_func, disable_conf):
-    data_gens =[string_gen,
-        byte_gen, short_gen, int_gen, long_gen, boolean_gen]
-
+    data_gens = [string_gen,
+        byte_gen, short_gen, int_gen, long_gen, boolean_gen] + decimal_gens
+
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(data_gens)]
     gen = StructGen(gen_list, nullable=False)
     data_path = spark_tmp_path + '/PARQUET_DATA'
@@ -103,8 +110,8 @@ def test_compress_read_round_trip(spark_tmp_path, compress, v1_enabled_list, rea
     byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
     boolean_gen, string_gen, date_gen,
     # Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with
-    # timestamp_gen
-    TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))]
+    # timestamp_gen
+    TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))] + decimal_gens
 
 @pytest.mark.parametrize('parquet_gen', parquet_pred_push_gens, ids=idfn)
 @pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
@@ -193,11 +200,27 @@ def test_ts_read_fails_datetime_legacy(gen, spark_tmp_path, ts_write, ts_rebase,
         lambda spark : readParquetCatchException(spark, data_path),
         conf=all_confs)
 
+
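+# spark.sql.parquet.writeLegacyFormat stores decimals as FIXED_LEN_BYTE_ARRAY rather than
+# INT32/INT64, so this test exercises the legacy decimal encoding on the read path.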
+@pytest.mark.parametrize('parquet_gens', [decimal_gens], ids=idfn)
+@pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
+@pytest.mark.parametrize('reader_confs', reader_opt_confs)
+@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
+def test_decimal_read_legacy(spark_tmp_path, parquet_gens, read_func, reader_confs, v1_enabled_list):
+    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
+    data_path = spark_tmp_path + '/PARQUET_DATA'
+    with_cpu_session(
+        lambda spark : gen_df(spark, gen_list).write.parquet(data_path),
+        conf={'spark.sql.parquet.writeLegacyFormat': 'true'})
+    all_confs = reader_confs.copy()
+    all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list})
+    assert_gpu_and_cpu_are_equal_collect(read_func(data_path), conf=all_confs)
+
+
 parquet_gens_legacy_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
-    string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
-    TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))],
-    pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133')),
-    pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133'))]
+    string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
+    TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))] + decimal_gens,
+    pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133')),
+    pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133'))]
 
 @pytest.mark.parametrize('parquet_gens', parquet_gens_legacy_list, ids=idfn)
 @pytest.mark.parametrize('reader_confs', reader_opt_confs)
@@ -221,7 +244,7 @@ def test_simple_partitioned_read(spark_tmp_path, v1_enabled_list, reader_confs):
     # we should go with a more standard set of generators
     parquet_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
     string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
-    TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))]
+    TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))] + decimal_gens
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
     first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0/key2=20'
     with_cpu_session(
@@ -295,7 +318,7 @@ def test_read_merge_schema(spark_tmp_path, v1_enabled_list, reader_confs):
     # we should go with a more standard set of generators
     parquet_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
     string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
-    TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))]
+    TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))] + decimal_gens
     first_gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
     first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0'
     with_cpu_session(
@@ -320,7 +343,7 @@ def test_read_merge_schema_from_conf(spark_tmp_path, v1_enabled_list, reader_con
     # we should go with a more standard set of generators
     parquet_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
     string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
-    TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))]
+    TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))] + decimal_gens
     first_gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
     first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0'
     with_cpu_session(
@@ -403,15 +426,15 @@ def test_small_file_memory(spark_tmp_path, v1_enabled_list):
 
 _nested_pruning_schemas = [
-        ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
+        ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
         [["a", StructGen([["c_1", StringGen()]])]]),
-        ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
+        ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
         [["a", StructGen([["c_2", LongGen()]])]]),
-        ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
+        ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
         [["a", StructGen([["c_3", ShortGen()]])]]),
-        ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
+        ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
         [["a", StructGen([["c_1", StringGen()], ["c_3", ShortGen()]])]]),
-        ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
+        ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
         [["a", StructGen([["c_3", ShortGen()], ["c_2", LongGen()], ["c_1", StringGen()]])]]),
         ([["ar", ArrayGen(StructGen([["str_1", StringGen()],["str_2", StringGen()]]))]],
         [["ar", ArrayGen(StructGen([["str_2", StringGen()]]))]])
diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
index 75b3992e5a3..bd413d34100 100644
--- a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
+++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
@@ -147,8 +147,9 @@ class Spark300Shims extends SparkShims {
     GpuOverrides.exec[FileSourceScanExec](
       "Reading data from files, often from Hive tables",
       ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
-          TypeSig.ARRAY).nested(), TypeSig.all),
+          TypeSig.ARRAY + TypeSig.DECIMAL).nested(), TypeSig.all),
       (fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {
+        // partition filters and data filters are not run on the GPU
         override val childExprs: Seq[ExprMeta[_]] = Seq.empty
diff --git a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala
index 119e2c9e770..4a8e2e1dab8 100644
--- a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala
+++ b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala
@@ -202,7 +202,7 @@ class Spark310Shims extends Spark301Shims {
     GpuOverrides.exec[FileSourceScanExec](
       "Reading data from files, often from Hive tables",
       ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
-          TypeSig.ARRAY).nested(), TypeSig.all),
+          TypeSig.ARRAY + TypeSig.DECIMAL).nested(), TypeSig.all),
       (fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {
         // partition filters and data filters are not run on the GPU
         override val childExprs: Seq[ExprMeta[_]] = Seq.empty
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala
index 01687c4a5cf..0b096ddf781 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DateType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.types.{DateType, DecimalType, StructField, StructType, TimestampType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.SerializableConfiguration
@@ -229,6 +229,10 @@ object GpuCSVScan {
       }
     }
     // TODO parsedOptions.emptyValueInRead
+
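+    // The GPU CSV parser cannot produce decimal columns yet, so fall back to the CPU when the
+    // read schema contains any DecimalType.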
+    if (readSchema.exists(_.dataType.isInstanceOf[DecimalType])) {
+      meta.willNotWorkOnGpu("DecimalType is not supported")
+    }
   }
 }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
index e4a0f10359a..cdb14c06aa4 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
@@ -115,7 +115,7 @@ object GpuOrcScanBase {
       meta.willNotWorkOnGpu("mergeSchema and schema evolution is not supported yet")
     }
     schema.foreach { field =>
-      if (!GpuColumnVector.isNonNestedSupportedType(field.dataType)) {
+      if (!GpuOverrides.isSupportedType(field.dataType)) {
         meta.willNotWorkOnGpu(s"GpuOrcScan does not support fields of type ${field.dataType}")
       }
     }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 3429bf178ec..af13ee6cacd 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.execution.datasources.v2.{AlterNamespaceSetPropertiesExec, AlterTableExec, AtomicReplaceTableExec, BatchScanExec, CreateNamespaceExec, CreateTableExec, DeleteFromTableExec, DescribeNamespaceExec, DescribeTableExec, DropNamespaceExec, DropTableExec, RefreshTableExec, RenameTableExec, ReplaceTableExec, SetCatalogAndNamespaceExec, ShowCurrentNamespaceExec, ShowNamespacesExec, ShowTablePropertiesExec, ShowTablesExec}
 import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
+import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.execution.python._
@@ -2193,7 +2194,8 @@ object GpuOverrides {
     exec[BatchScanExec](
       "The backend for most file input",
       ExecChecks(
-        (TypeSig.commonCudfTypes + TypeSig.STRUCT + TypeSig.MAP + TypeSig.ARRAY).nested(),
+        (TypeSig.commonCudfTypes + TypeSig.STRUCT + TypeSig.MAP + TypeSig.ARRAY +
+            TypeSig.DECIMAL).nested(),
         TypeSig.all),
       (p, conf, parent, r) => new SparkPlanMeta[BatchScanExec](p, conf, parent, r) {
         override val childScans: scala.Seq[ScanMeta[_]] =
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
index 895d2526b03..0bb577f4e2c 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -139,7 +139,8 @@ object GpuParquetScanBase {
         allowMaps = true,
         allowArray = true,
         allowStruct = true,
-        allowNesting = true)) {
+        allowNesting = true,
+        allowDecimal = meta.conf.decimalTypeEnabled)) {
       meta.willNotWorkOnGpu(s"GpuParquetScan does not support fields of type ${field.dataType}")
     }
   }
@@ -197,6 +198,33 @@ object GpuParquetScanBase {
       meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
     }
   }
+
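+  // The cuDF Parquet reader can hand back DECIMAL32 columns for low-precision decimal fields;
+  // cast them up front to DECIMAL64, the only decimal layout the plugin currently handles.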
+  private[rapids] def convertDecimal32Columns(t: Table): Table = {
+    val containDecimal32Column = (0 until t.getNumberOfColumns).exists { i =>
+      t.getColumn(i).getType.getTypeId == DType.DTypeEnum.DECIMAL32
+    }
+    // return the input table if there are no DECIMAL32 columns
+    if (!containDecimal32Column) return t
+
+    val columns = new Array[ColumnVector](t.getNumberOfColumns)
+    try {
+      RebaseHelper.withResource(t) { _ =>
+        (0 until t.getNumberOfColumns).foreach { i =>
+          t.getColumn(i).getType match {
+            case tpe if tpe.getTypeId == DType.DTypeEnum.DECIMAL32 =>
+              columns(i) = t.getColumn(i).castTo(
+                DType.create(DType.DTypeEnum.DECIMAL64, tpe.getScale))
+            case _ =>
+              columns(i) = t.getColumn(i).incRefCount()
+          }
+        }
+      }
+      new Table(columns: _*)
+    } finally {
+      // clean temporary column vectors
+      columns.safeClose()
+    }
+  }
 }
 
 /**
@@ -657,13 +685,16 @@ abstract class FileParquetPartitionReaderBase(
       inputTable: Table,
       filePath: String,
       clippedSchema: MessageType): Table = {
-    if (readDataSchema.length > inputTable.getNumberOfColumns) {
+    // Convert Decimal32 columns to Decimal64, because spark-rapids only supports Decimal64.
+    val inTable = GpuParquetScanBase.convertDecimal32Columns(inputTable)
+
+    if (readDataSchema.length > inTable.getNumberOfColumns) {
       // Spark+Parquet schema evolution is relatively simple with only adding/removing columns
       // To type casting or anyting like that
       val clippedGroups = clippedSchema.asGroupType()
       val newColumns = new Array[ColumnVector](readDataSchema.length)
       try {
-        withResource(inputTable) { table =>
+        withResource(inTable) { table =>
           var readAt = 0
           (0 until readDataSchema.length).foreach(writeAt => {
             val readField = readDataSchema(writeAt)
@@ -686,7 +717,7 @@ abstract class FileParquetPartitionReaderBase(
         newColumns.safeClose()
       }
     } else {
-      inputTable
+      inTable
     }
   }
@@ -1115,6 +1146,7 @@ class MultiFileParquetPartitionReader(
     }
     val parseOpts = ParquetOptions.builder()
       .withTimeUnit(DType.TIMESTAMP_MICROSECONDS)
+      .enableStrictDecimalType(true)
       .includeColumn(readDataSchema.fieldNames:_*).build()
 
     // about to start using the GPU
@@ -1523,6 +1555,7 @@ class MultiFileCloudParquetPartitionReader(
     }
     val parseOpts = ParquetOptions.builder()
       .withTimeUnit(DType.TIMESTAMP_MICROSECONDS)
+      .enableStrictDecimalType(true)
       .includeColumn(readDataSchema.fieldNames: _*).build()
 
     // about to start using the GPU
@@ -1658,6 +1691,7 @@ class ParquetPartitionReader(
     }
     val parseOpts = ParquetOptions.builder()
      .withTimeUnit(DType.TIMESTAMP_MICROSECONDS)
+      .enableStrictDecimalType(true)
       .includeColumn(readDataSchema.fieldNames:_*).build()
 
     // about to start using the GPU
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala
index e9ad8324307..295655972f6 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala
@@ -127,6 +127,7 @@ object HostColumnarToGpu {
         if (cv.isNullAt(i)) {
           b.appendNull()
         } else {
+          // The precision here matters for cpu column vectors (such as OnHeapColumnVector).
           b.append(cv.getDecimal(i, dt.precision, dt.scale).toUnscaledLong)
         }
       }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 6cdf7adf8dc..2fd76aef3b9 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -1236,13 +1236,13 @@ object SupportedOpsDocs {
         println("S") // DATE
         println("S") // TIMESTAMP
         println("S") // STRING
-        println("NS") // DECIMAL
+        println("S") // DECIMAL
         println("") // NULL
         println("NS") // BINARY
         println("") // CALENDAR
-        println("PS (missing nested DECIMAL, BINARY)") // ARRAY
-        println("PS (missing nested DECIMAL, BINARY)") // MAP
-        println("PS (missing nested DECIMAL, BINARY)") // STRUCT
+        println("PS (missing nested BINARY)") // ARRAY
+        println("PS (missing nested BINARY)") // MAP
+        println("PS (missing nested BINARY)") // STRUCT
         println("")
         println("")
         println("Output")
diff --git a/tests/src/test/resources/decimal-test-legacy.parquet b/tests/src/test/resources/decimal-test-legacy.parquet
new file mode 100644
index 00000000000..969f049848f
Binary files /dev/null and b/tests/src/test/resources/decimal-test-legacy.parquet differ
diff --git a/tests/src/test/resources/decimal-test.csv b/tests/src/test/resources/decimal-test.csv
new file mode 100644
index 00000000000..1fd4646904c
--- /dev/null
+++ b/tests/src/test/resources/decimal-test.csv
@@ -0,0 +1,100 @@
+915270249210239718,3232.792,"",771.371173049837,-431710025170174585,4.7239953E24
+50004804273312941,9263.400,0.3815050595,900.890874730220,-2697073954890740236,7.463101E-31
+189216077028719828,8094.536,0.7879423094,817.584127883600,-5133656973475552689,2.4919543E-32
+257886722013221592,7097.760,0.5551530993,263.532647425188,-3917032101531217289,5.176538E7
+223616015091255874,2021.691,0.4300789101,741.876603938996,7810001276519378488,6.673943E9
+672269487046710336,5925.644,0.5119009979,677.273069670187,"",-3.037228E-26
+565739933980371374,2082.917,0.6325578476,307.790881040462,-7858370133784586516,-1.7402571E-20
+30846051239772526,2153.584,"",428.457572476468,2224676899470531349,-3.720298E-28
+292750338676616377,913.012,0.7647953732,329.427928877243,1287002480498025462,963.7641
+863597087074282135,7771.993,0.9378547775,655.942008799151,-4594481394522420980,-5.143567E-37
+740916466628308680,3601.327,0.5253287765,293.314491660678,3663204060287939255,-7.328485E30
+543922875310806589,418.135,0.0089806121,678.725898059834,7100991336799471692,""
+907927561275971206,4357.321,0.0894175775,985.715625304095,8980271347953950594,2.036806E-11
+5365651368778670,28.443,0E-10,98.984414373371,2753461187700946189,""
+98975573402945273,5544.347,0.3277681904,780.987096259467,-2821373815087770381,""
+37654707856660412,3245.419,0.7249053714,"","",5.489915E-15
+417692898406404444,5137.551,0.0654631387,8.906002611085,-3073072854586574157,2.8559989E-5
+30008375870039084,1638.013,0.5689759921,506.427830699217,8001350543913869500,0.0010491271
+267647485487795321,711.758,0.7086510671,787.214975034435,0,-4.294404E35
+32497433138887044,1500.500,0.1632851254,111.862479713755,-2124845351861490171,-2911376.8
+563219648467491123,2533.799,0.0470860444,-999.999999999999,-6150492885085058765,-3.1956E-21
+698244577193394337,7710.613,0.1796242319,177.658429265869,-7293294899695747856,1.55815667E12
+59984767260993922,9598.853,0.9606575721,827.362102857851,7956328026836954420,-2.4386406E11
+185053572335775931,1504.547,0.5687599560,418.547722108074,-8496278226268642122,-1.1115532E-32
+406076905698949206,1584.821,0.9073638746,454.572717017033,9048728756392488113,-3.6374905E-7
+370626360548153088,6979.091,0.7359568576,847.384858615533,4440968842303227453,3.9915457E-28
+880266977576630903,3969.112,0.0536779482,32.241735406630,6086331466042394437,-2.1541503E-33
+914944620826734702,258.240,0.7421725875,413.873958949421,-7351387047220061849,-8.1657964E-36
+615190575294467984,8101.993,0.4021261498,440.941554482483,2251193521986836658,1.7728261E25
+748719415383076949,928.200,0.0701185195,307.250677745230,-255208803392541585,""
+"",8482.048,0.0875250620,889.559939012387,131095898465527538,-0.0
+541457418990604771,298.713,"",356.319195755119,8385038430064196353,-1.1713938E26
+998278945156788331,3126.578,0.0190951652,277.824073520913,5884635015085662155,2.7177546E22
+648685475062397368,955.643,0.4443410992,569.890024734869,-5773983929953240007,3.5561367E32
+128835062166543528,6830.727,"",967.668187303342,-7732708842585057603,-1.69050099E9
+979181842441473467,2627.146,0.0931873144,308.242155826852,6658806232357759343,-3.0718065E22
+373383053238313709,5880.783,0.8089926358,288.091654931223,2449340209240679641,1.8817666E-34
+369917108588895763,"",0.0784246739,419.063742180292,5469773751425577039,6.3262706E-10
+959109271119159898,5963.914,0.5131396358,"",-3786594002235441489,-1.0057181E33
+775063852272099258,0.000,0.1505213451,419.677059264876,-8824564747028895090,5.758003E-13
+495443091472185622,5536.609,0.6390268097,739.197371445563,1760002372390556825,6816057.5
+25146710634437912,3067.206,0.2161656375,215.056698028213,803782761142075607,-54089.32
+613858510238074126,957.490,0.4822075927,518.956895534893,2437026550498167082,-2.4735995E-5
+255794834605228757,899.249,0.8092193887,588.233634837773,-525199667559164571,""
+680825620982324925,5517.827,0.0065746462,193.251744967577,-5041734557262022470,-1485983.4
+714779301951357809,1711.747,0.6136524686,686.957950228997,-6839627708332633308,""
+381520952215527508,5424.226,0.2305893771,549.262462812814,-2059623267661777620,-3.7695377E-10
+669984683600664919,9136.075,0.7850895020,209.124878112725,-2710360188048040119,4.9402314E28
+182598193716752431,2503.031,0.0070287326,3.249783083502,-6633320123827272048,4.1270598E-11
+"",2596.200,0.8936980935,85.288926615679,662259464231680319,""
+40023869813864036,233.012,0.5436629243,171.799250634543,-3664974252843436884,-1.1522237E21
+710322854240541412,4019.609,"",446.549422285354,8834022966795744609,-1.00241495E-20
+841507640518767629,5834.872,0.4100687936,369.576552043052,6075221653337964625,-3.9562905E-11
+857872153237373925,4796.207,0.0756528306,575.779939808894,3893516458893827324,3.1058907E-18
+25775526817610369,548.169,0.0669838937,717.645262020503,-1,-2.9235664E-31
+25215292382715618,"",0.8464787764,632.667903196574,3221645619906578280,3.8987961E-19
+47171113226028385,8168.750,0.9055021741,35.058703401277,8776159873495597953,-3.7015653E-27
+943951645776822973,8443.481,0.3052030122,807.819184192606,-199240735297494389,2.9509407E32
+650796227970751779,6957.306,0.3522180815,680.869699320152,5264124301174436230,-8.20616E-30
+750204382327921920,"",0.9373742004,143.248267456298,-6228871387240662005,7.737559E-19
+612727546233536520,2940.034,0.5992235971,277.240099823954,9223372036854775807,-1.4193973E-5
+248544869063573270,4405.459,0.4938911635,185.838164288601,3780377771042673290,-4.794553E-20
+"",7286.797,0.0525265125,821.424386490011,8226475382227724456,-1.0
+797023704927729407,8307.766,0.3520962355,"",2370991850373366052,-9.2752095E-33
+733656328652598115,9387.470,"",292.371386489063,-772608278816658373,6.0993237
+456487310339748192,"",0.8716752095,759.109542616645,-8516846031062369157,-1.9623711E-20
+12735651342573944,463.904,0.7174409373,999.999999999999,8704401706019656847,1.6802378E27
+503142435912090290,4817.213,0.3841225840,9.127318172586,1196628311801793151,1.0
+227184013575342512,5661.152,0.3130044658,529.075843748348,-5761755906933392982,-1.2703339E27
+"",1738.648,0.1635048458,456.258113062836,1937750570385130232,-547.00464
+786474267467256790,9999.999,0.7247985984,246.846730055841,1,221027.05
+5722907039468293,1494.945,0.4833744532,521.189835443092,"",79.40845
+498893766826535453,-9999.999,0.8314235080,448.448825006248,-4757790292222279178,-3.1900585E26
+746704316885770352,7426.798,0.0697542832,933.708683277599,8123776869365306368,8367.281
+242857837293930678,72.927,0.4892817062,593.300614078262,1525346850676520527,2.4633369E-28
+267180693163965584,5567.777,0.1129542164,862.736499156329,-7943660520670333047,-2.5013383E34
+863179009301888915,"",0.5913615439,559.544347644078,-6218179828993870307,""
+628941572681954168,672.616,0.1120588881,726.724012853591,8295107456982566271,-3.601133E-26
+195085765272922078,3267.147,0.6227582353,942.184595963346,-7901787385806075056,5.208774E-15
+54795121483994580,7573.962,0.0232430822,530.957948073184,2519241821679825052,2.3091825E27
+11025911078984278,9080.678,0.9256236807,785.967159932085,3898714205093691132,-5.299004E-27
+585416995369880382,3180.974,0.6636002892,399.741020816154,8287847947794918753,-4.328721E-24
+93056155170236349,9851.699,0.4570285677,870.296415087233,7848530628017602134,1.6895455E-36
+273518414773911153,541.209,0.0030913568,794.949286358568,-9209970347163780353,9.984763E-9
+950579123902836569,2958.730,"","",-8408605963020785199,-152976.69
+294812000018063416,5947.866,0.3776552591,71.442236681325,"",1.89517168E8
+999999999999999999,4313.357,0.0052879447,534.027817103663,-1499033162028359415,3.3824432E-10
+47797026331225179,2523.889,0.0202628702,283.620687243088,9223372036854775807,-9.904269E28
+945353911554496947,9714.635,0.5445525623,577.494999825168,-4851177952341996882,5.4106984E7
+"",706.236,0.0574360283,868.681442920555,0,-9.979879E32
+362850502113699515,3648.749,0.4765617370,751.474476720116,4674971002257188808,1.8771697E32
+263209284245630819,9532.742,0.5552251822,191.831875430596,-2824608313884610115,1.6279153E-21
+553774330947428625,166.791,0.9306022422,79.427117038552,-5012912027038200187,9.0646643E10
+607563227695541003,4512.145,0.6783432188,901.316884058369,7603145021478615614,NaN
+159291016019700529,5240.183,0.7258544713,350.623491123040,-1,-3.2958715E-15
+619900339634591041,7564.722,0.1887079611,994.209902879431,-6112563531377075787,-2.2229757E-33
+831245367052238890,4229.978,0.5127881028,639.349614940198,7280508167761829381,1.8915695E21
+868372001868812448,7565.051,0.1726353517,344.897092976059,1197973862304902753,-3.4028235E38
+"",8757.507,0.8479961242,429.256072226970,-503988022213926193,1.22784955E-32
+590855156246510004,1298.748,0.3022440975,961.920179785774,-4325271223339769315,-7.483853E-22
diff --git a/tests/src/test/resources/decimal-test.orc b/tests/src/test/resources/decimal-test.orc
new file mode 100644
index 00000000000..8396738735f
Binary files /dev/null and b/tests/src/test/resources/decimal-test.orc differ
diff --git a/tests/src/test/resources/decimal-test.parquet b/tests/src/test/resources/decimal-test.parquet
new file mode 100644
index 00000000000..2d029c704be
Binary files /dev/null and b/tests/src/test/resources/decimal-test.parquet differ
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ParquetScanSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ParquetScanSuite.scala
index 14cf270bc93..0942b2c0346 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/ParquetScanSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/ParquetScanSuite.scala
@@ -16,7 +16,11 @@
 
 package com.nvidia.spark.rapids
 
+import java.io.File
+import java.nio.file.Files
+
 import org.apache.spark.SparkConf
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.functions.col
 
 class ParquetScanSuite extends SparkQueryCompareTestSuite {
@@ -38,4 +42,18 @@ class ParquetScanSuite extends SparkQueryCompareTestSuite {
     frameFromParquet("timestamp-date-test.parquet")) {
       frame => frame.select(col("*"))
   }
+
+  // Column schema of decimal-test.parquet is: [c_0: decimal(18, 0), c_1: decimal(7, 3),
+  // c_2: decimal(10, 10), c_3: decimal(15, 12), c_4: int64, c_5: float]
+  testSparkResultsAreEqual("Test Parquet decimal stored as INT32/64",
+    frameFromParquet("decimal-test.parquet")) {
+      frame => frame.select(col("*"))
+  }
+
+  // Column schema of decimal-test-legacy.parquet is: [c_0: decimal(18, 0), c_1: decimal(7, 3),
+  // c_2: decimal(10, 10), c_3: decimal(15, 12), c_4: int64, c_5: float]
+  testSparkResultsAreEqual("Test Parquet decimal stored as FIXED_LEN_BYTE_ARRAY",
+    frameFromParquet("decimal-test-legacy.parquet")) {
+      frame => frame.select(col("*"))
+  }
 }
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala b/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala
index bfb7886b4ec..924bd4de152 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala
@@ -21,10 +21,16 @@ import java.math.RoundingMode
 import scala.util.Random
 
 import ai.rapids.cudf.{ColumnVector, DType, HostColumnVector}
-import com.nvidia.spark.rapids.{GpuAlias, GpuColumnVector, GpuIsNotNull, GpuIsNull, GpuLiteral, GpuOverrides, GpuScalar, GpuUnitTests, HostColumnarToGpu, RapidsConf}
+import com.nvidia.spark.rapids.{GpuAlias, GpuBatchScanExec, GpuColumnVector, GpuIsNotNull, GpuIsNull, GpuLiteral, GpuOverrides, GpuScalar, GpuUnitTests, HostColumnarToGpu, RapidsConf}
 
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Literal}
-import org.apache.spark.sql.types.{Decimal, DecimalType}
+import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.rapids.GpuFileSourceScanExec
+import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, LongType, StructField, StructType}
 
 class DecimalUnitTest extends GpuUnitTests {
   Random.setSeed(1234L)
@@ -260,4 +266,34 @@ class DecimalUnitTest extends GpuUnitTests {
       }
     }
   }
+
+  test("test type checking of Scans") {
+    val conf = new SparkConf().set(RapidsConf.DECIMAL_TYPE_ENABLED.key, "true")
+      .set(RapidsConf.TEST_ALLOWED_NONGPU.key, "BatchScanExec,ColumnarToRowExec,FileSourceScanExec")
+    val decimalCsvStruct = StructType(Array(
+      StructField("c_0", DecimalType(18, 0), true),
+      StructField("c_1", DecimalType(7, 3), true),
+      StructField("c_2", DecimalType(10, 10), true),
+      StructField("c_3", DecimalType(15, 12), true),
+      StructField("c_4", LongType, true),
+      StructField("c_5", IntegerType, true)))
+
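+    // With the default v1 data sources, the ORC and CSV decimal scans are expected to fall back
+    // to the CPU (FileSourceScanExec) while the Parquet scan runs on the GPU; the second session
+    // repeats the same checks for the v2 BatchScanExec path.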
+    withGpuSparkSession((ss: SparkSession) => {
+      var rootPlan = frameFromOrc("decimal-test.orc")(ss).queryExecution.executedPlan
+      assert(rootPlan.map(p => p).exists(_.isInstanceOf[FileSourceScanExec]))
+      rootPlan = fromCsvDf("decimal-test.csv", decimalCsvStruct)(ss).queryExecution.executedPlan
+      assert(rootPlan.map(p => p).exists(_.isInstanceOf[FileSourceScanExec]))
+      rootPlan = frameFromParquet("decimal-test.parquet")(ss).queryExecution.executedPlan
+      assert(rootPlan.map(p => p).exists(_.isInstanceOf[GpuFileSourceScanExec]))
+    }, conf)
+
+    withGpuSparkSession((ss: SparkSession) => {
+      var rootPlan = frameFromOrc("decimal-test.orc")(ss).queryExecution.executedPlan
+      assert(rootPlan.map(p => p).exists(_.isInstanceOf[BatchScanExec]))
+      rootPlan = fromCsvDf("decimal-test.csv", decimalCsvStruct)(ss).queryExecution.executedPlan
+      assert(rootPlan.map(p => p).exists(_.isInstanceOf[BatchScanExec]))
+      rootPlan = frameFromParquet("decimal-test.parquet")(ss).queryExecution.executedPlan
+      assert(rootPlan.map(p => p).exists(_.isInstanceOf[GpuBatchScanExec]))
+    }, conf.set(SQLConf.USE_V1_SOURCE_LIST.key, ""))
+  }
 }