diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 4a34de392f4..511a3611803 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -465,7 +465,7 @@ Accelerator supports are described below.
 S
 S*
 S
-NS
+PS* (Only supported for Parquet)
 NS
 NS
 NS
@@ -17821,7 +17821,7 @@ dates or timestamps, or for a lack of type coercion support.
 S
 S
 S
-NS
+S
 NS
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index 56d0cee1c46..42678406053 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -220,7 +220,6 @@ def __init__(self, precision=None, scale=None, nullable=True, special_cases=[]):
             scale = 0
         DECIMAL_MIN = Decimal('-' + ('9' * precision) + 'e' + str(-scale))
         DECIMAL_MAX = Decimal(('9'* precision) + 'e' + str(-scale))
-        special_cases = [Decimal('0'), Decimal(DECIMAL_MIN), Decimal(DECIMAL_MAX)]
         super().__init__(DecimalType(precision, scale), nullable=nullable, special_cases=special_cases)
         self._scale = scale
         self._precision = precision
diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py
index d7d81bc1038..e53af274eb8 100644
--- a/integration_tests/src/main/python/parquet_write_test.py
+++ b/integration_tests/src/main/python/parquet_write_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2021, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
 from marks import *
 from pyspark.sql.types import *
 from spark_session import with_cpu_session, with_gpu_session, is_before_spark_310
+import random

 # test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for
 # non-cloud
@@ -33,8 +34,12 @@
            'spark.sql.legacy.parquet.int96RebaseModeInWrite': 'CORRECTED'}

 parquet_write_gens_list = [
-    [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
-     string_gen, boolean_gen, date_gen, timestamp_gen]]
+    [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
+     string_gen, boolean_gen, date_gen, timestamp_gen],
+    pytest.param([byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
+        string_gen, boolean_gen, date_gen, timestamp_gen, decimal_gen_default,
+        decimal_gen_scale_precision, decimal_gen_same_scale_precision, decimal_gen_64bit],
+        marks=pytest.mark.allow_non_gpu("CoalesceExec"))]

 parquet_ts_write_options = ['INT96', 'TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS']

@@ -42,7 +47,8 @@
 @pytest.mark.parametrize('reader_confs', reader_opt_confs)
 @pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
 @pytest.mark.parametrize('ts_type', parquet_ts_write_options)
-def test_write_round_trip(spark_tmp_path, parquet_gens, v1_enabled_list, ts_type, reader_confs):
+def test_write_round_trip(spark_tmp_path, parquet_gens, v1_enabled_list, ts_type,
+                          reader_confs):
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
     data_path = spark_tmp_path + '/PARQUET_DATA'
     all_confs = reader_confs.copy()
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 4c77c4a2a78..a111a8e6993 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -314,7 +314,7 @@ final class InsertIntoHadoopFsRelationCommandMeta(
         willNotWorkOnGpu("JSON output is not supported")
         None
       case f if GpuOrcFileFormat.isSparkOrcFormat(f) =>
-        GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.options)
+        GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.options, cmd.query.schema)
       case _: ParquetFileFormat =>
         GpuParquetFileFormat.tagGpuSupport(this, spark, cmd.options, cmd.query.schema)
       case _: TextFileFormat =>
@@ -371,7 +371,7 @@ final class CreateDataSourceTableAsSelectCommandMeta(
     // If that changes then this will start failing because we don't have a mapping.
     gpuProvider = origProvider.getConstructor().newInstance() match {
       case f: FileFormat if GpuOrcFileFormat.isSparkOrcFormat(f) =>
-        GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.table.storage.properties)
+        GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.table.storage.properties, cmd.query.schema)
       case _: ParquetFileFormat =>
         GpuParquetFileFormat.tagGpuSupport(this, spark, cmd.table.storage.properties,
           cmd.query.schema)
@@ -2344,7 +2344,9 @@ object GpuOverrides {
     }),
     exec[DataWritingCommandExec](
       "Writing data",
-      ExecChecks(TypeSig.commonCudfTypes, TypeSig.all),
+      ExecChecks((TypeSig.commonCudfTypes +
+        TypeSig.DECIMAL.withPsNote(TypeEnum.DECIMAL, "Only supported for Parquet")).nested(),
+        TypeSig.all),
       (p, conf, parent, r) => new SparkPlanMeta[DataWritingCommandExec](p, conf, parent, r) {
         override val childDataWriteCmds: scala.Seq[DataWritingCommandMeta[_]] =
           Seq(GpuOverrides.wrapDataWriteCmds(p.cmd, conf, Some(this)))
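The `ExecChecks` change in the hunk above is what feeds both documentation edits in this diff: `withPsNote` renders as the "PS* (Only supported for Parquet)" cell in docs/supported_ops.md, and `.nested()` extends support to the same types when they appear inside arrays and structs. A minimal sketch of how the signature composes, using only names that appear in this diff (the surrounding `exec[...]` registration is elided):

```scala
// Sketch only: TypeSig values compose with `+`. withPsNote attaches the
// partial-support footnote that the doc generator prints next to "PS*",
// and nested() applies the combined signature to nested (array/struct)
// element types as well as top-level columns.
val writeSideSupport = (TypeSig.commonCudfTypes +
  TypeSig.DECIMAL.withPsNote(TypeEnum.DECIMAL, "Only supported for Parquet")).nested()
```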
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
index 98ba2a92bbf..5668ac3bcd7 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,7 +32,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.rapids.ColumnarWriteTaskStatsTracker
 import org.apache.spark.sql.rapids.execution.TrampolineUtil
-import org.apache.spark.sql.types.{DataTypes, DateType, StructType, TimestampType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DateType, DecimalType, MapType, StructType, TimestampType}
 import org.apache.spark.sql.vectorized.ColumnarBatch

 object GpuParquetFileFormat {
@@ -41,6 +41,13 @@ object GpuParquetFileFormat {
       spark: SparkSession,
       options: Map[String, String],
       schema: StructType): Option[GpuParquetFileFormat] = {
+
+    val unSupportedTypes =
+      schema.filterNot(field => GpuOverrides.isSupportedType(field.dataType, allowDecimal = true))
+    if (unSupportedTypes.nonEmpty) {
+      meta.willNotWorkOnGpu(s"These types aren't supported for parquet $unSupportedTypes")
+    }
+
     val sqlConf = spark.sessionState.conf
     val parquetOptions = new ParquetOptions(options, sqlConf)

@@ -247,31 +254,45 @@ class GpuParquetWriter(
    */
   override def write(batch: ColumnarBatch,
       statsTrackers: Seq[ColumnarWriteTaskStatsTracker]): Unit = {
+    val outputMillis = outputTimestampType == ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString
     val newBatch =
-      if (outputTimestampType == ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) {
-        new ColumnarBatch(GpuColumnVector.extractColumns(batch).map {
-          cv => {
-            cv.dataType() match {
-              case DataTypes.TimestampType => new GpuColumnVector(DataTypes.TimestampType,
-                withResource(cv.getBase()) { v =>
-                  v.castTo(DType.TIMESTAMP_MILLISECONDS)
-                })
-              case _ => cv
-            }
+      new ColumnarBatch(GpuColumnVector.extractColumns(batch).map {
+        cv => {
+          cv.dataType() match {
+            case DataTypes.TimestampType if outputMillis =>
+              new GpuColumnVector(DataTypes.TimestampType, withResource(cv.getBase()) { v =>
+                v.castTo(DType.TIMESTAMP_MILLISECONDS)
+              })
+            case d: DecimalType if d.precision < 10 =>
+              // There is a bug in Spark that causes a problem if we write Decimals with
+              // precision < 10 as Decimal64.
+              // https://issues.apache.org/jira/browse/SPARK-34167
+              new GpuColumnVector(d, withResource(cv.getBase()) { v =>
+                v.castTo(DType.create(DType.DTypeEnum.DECIMAL32, -d.scale))
+              })
+            case _ => cv
           }
-        })
-      } else {
-        batch
-      }
+        }
+      })

     super.write(newBatch, statsTrackers)
   }

   override val tableWriter: TableWriter = {
+    def precisionsList(t: DataType): Seq[Int] = {
+      t match {
+        case d: DecimalType => List(d.precision)
+        case s: StructType => s.flatMap(f => precisionsList(f.dataType))
+        case ArrayType(elementType, _) => precisionsList(elementType)
+        case _ => List.empty
+      }
+    }
+
     val writeContext = new ParquetWriteSupport().init(conf)
     val builder = ParquetWriterOptions.builder()
       .withMetadata(writeContext.getExtraMetaData)
       .withCompressionType(compressionType)
       .withTimestampInt96(outputTimestampType == ParquetOutputTimestampType.INT96)
+      .withPrecisionValues(dataSchema.flatMap(f => precisionsList(f.dataType)):_*)
     dataSchema.foreach(entry => {
       if (entry.nullable) {
         builder.withColumnNames(entry.name)
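Two details in the hunk above deserve a note, since the reasoning is spread across it. First, cudf encodes decimal scale with the opposite sign from Spark, which is why both casts pass `-d.scale`. Second, nine decimal digits are the most that fit in a 32-bit integer, so `precision < 10` is exactly the range where DECIMAL32 is available and where SPARK-34167 makes writing Decimal64 unsafe. A hedged sketch of the resulting mapping; `cudfDecimalType` is an illustrative name of ours, not a plugin API, and the DECIMAL64 branch is the assumed default for precision 10 through 18 (exercised by `decimal_gen_64bit` in the new tests):

```scala
import ai.rapids.cudf.DType
import org.apache.spark.sql.types.DecimalType

// Illustrative only: which physical cudf decimal type a Spark decimal column
// is written as. cudf takes a negated scale relative to Spark's convention.
def cudfDecimalType(d: DecimalType): DType =
  if (d.precision < 10) {
    // 10^9 - 1 < Int.MaxValue, so up to 9 digits fit in 32 bits; writing these
    // as Decimal64 trips SPARK-34167 when Spark reads the file back.
    DType.create(DType.DTypeEnum.DECIMAL32, -d.scale)
  } else {
    DType.create(DType.DTypeEnum.DECIMAL64, -d.scale)
  }
```

Relatedly, `precisionsList` flattens nested schemas depth-first because, as wired above, `withPrecisionValues` receives a single flat varargs list with one entry per decimal column encountered.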
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 714a16779b4..76915acaa70 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -1313,7 +1313,7 @@
     println("S") // DATE
     println("S") // TIMESTAMP
     println("S") // STRING
-    println("NS") // DECIMAL
+    println("S") // DECIMAL
     println("") // NULL
     println("NS") // BINARY
     println("") // CALENDAR
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
index 80e11ff338d..8bc99822d77 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -44,7 +44,13 @@ object GpuOrcFileFormat extends Logging {
   def tagGpuSupport(meta: RapidsMeta[_, _, _],
                     spark: SparkSession,
-                    options: Map[String, String]): Option[GpuOrcFileFormat] = {
+                    options: Map[String, String],
+                    schema: StructType): Option[GpuOrcFileFormat] = {
+
+    val unSupportedTypes = schema.filterNot(field => GpuOverrides.isSupportedType(field.dataType))
+    if (unSupportedTypes.nonEmpty) {
+      meta.willNotWorkOnGpu(s"These types aren't supported for orc $unSupportedTypes")
+    }

     if (!meta.conf.isOrcEnabled) {
       meta.willNotWorkOnGpu("ORC input and output has been disabled. To enable set" +
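A closing observation that ties the two new schema checks together: the Parquet tagger opts in to decimals via `isSupportedType(field.dataType, allowDecimal = true)`, while the ORC tagger above uses the single-argument form, so decimal writes still fall back to the CPU for ORC. That asymmetry is the source of the "Only supported for Parquet" note in the `ExecChecks` and in docs/supported_ops.md. A minimal sketch of the observable difference, assuming `allowDecimal` defaults to false (implied by the diff, not shown in it):

```scala
import org.apache.spark.sql.types.{DecimalType, StructField, StructType}

// Illustrative only: the same decimal schema passes the Parquet-side filter
// but is caught by the ORC-side one, so only the ORC write gets tagged
// willNotWorkOnGpu at planning time.
val schema = StructType(Seq(StructField("price", DecimalType(9, 2))))
val badForParquet =
  schema.filterNot(f => GpuOverrides.isSupportedType(f.dataType, allowDecimal = true)) // empty
val badForOrc =
  schema.filterNot(f => GpuOverrides.isSupportedType(f.dataType)) // Seq(StructField("price", ...))
```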