diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 4a34de392f4..511a3611803 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -465,7 +465,7 @@ Accelerator supports are described below.
S |
S* |
S |
-NS |
+PS* (Only supported for Parquet) |
NS |
NS |
NS |
@@ -17821,7 +17821,7 @@ dates or timestamps, or for a lack of type coercion support.
S |
S |
S |
-NS |
+S |
|
NS |
|
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index 56d0cee1c46..42678406053 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -220,7 +220,6 @@ def __init__(self, precision=None, scale=None, nullable=True, special_cases=[]):
scale = 0
DECIMAL_MIN = Decimal('-' + ('9' * precision) + 'e' + str(-scale))
DECIMAL_MAX = Decimal(('9'* precision) + 'e' + str(-scale))
- special_cases = [Decimal('0'), Decimal(DECIMAL_MIN), Decimal(DECIMAL_MAX)]
super().__init__(DecimalType(precision, scale), nullable=nullable, special_cases=special_cases)
self._scale = scale
self._precision = precision
diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py
index d7d81bc1038..e53af274eb8 100644
--- a/integration_tests/src/main/python/parquet_write_test.py
+++ b/integration_tests/src/main/python/parquet_write_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
from marks import *
from pyspark.sql.types import *
from spark_session import with_cpu_session, with_gpu_session, is_before_spark_310
+import random
# test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for
# non-cloud
@@ -33,8 +34,12 @@
'spark.sql.legacy.parquet.int96RebaseModeInWrite': 'CORRECTED'}
parquet_write_gens_list = [
- [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
- string_gen, boolean_gen, date_gen, timestamp_gen]]
+ [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
+ string_gen, boolean_gen, date_gen, timestamp_gen],
+ pytest.param([byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
+ string_gen, boolean_gen, date_gen, timestamp_gen, decimal_gen_default,
+ decimal_gen_scale_precision, decimal_gen_same_scale_precision, decimal_gen_64bit],
+ marks=pytest.mark.allow_non_gpu("CoalesceExec"))]
parquet_ts_write_options = ['INT96', 'TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS']
@@ -42,7 +47,8 @@
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
@pytest.mark.parametrize('ts_type', parquet_ts_write_options)
-def test_write_round_trip(spark_tmp_path, parquet_gens, v1_enabled_list, ts_type, reader_confs):
+def test_write_round_trip(spark_tmp_path, parquet_gens, v1_enabled_list, ts_type,
+ reader_confs):
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
data_path = spark_tmp_path + '/PARQUET_DATA'
all_confs = reader_confs.copy()
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 4c77c4a2a78..a111a8e6993 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -314,7 +314,7 @@ final class InsertIntoHadoopFsRelationCommandMeta(
willNotWorkOnGpu("JSON output is not supported")
None
case f if GpuOrcFileFormat.isSparkOrcFormat(f) =>
- GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.options)
+ GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.options, cmd.query.schema)
case _: ParquetFileFormat =>
GpuParquetFileFormat.tagGpuSupport(this, spark, cmd.options, cmd.query.schema)
case _: TextFileFormat =>
@@ -371,7 +371,7 @@ final class CreateDataSourceTableAsSelectCommandMeta(
// If that changes then this will start failing because we don't have a mapping.
gpuProvider = origProvider.getConstructor().newInstance() match {
case f: FileFormat if GpuOrcFileFormat.isSparkOrcFormat(f) =>
- GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.table.storage.properties)
+ GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.table.storage.properties, cmd.query.schema)
case _: ParquetFileFormat =>
GpuParquetFileFormat.tagGpuSupport(this, spark,
cmd.table.storage.properties, cmd.query.schema)
@@ -2344,7 +2344,9 @@ object GpuOverrides {
}),
exec[DataWritingCommandExec](
"Writing data",
- ExecChecks(TypeSig.commonCudfTypes, TypeSig.all),
+ ExecChecks((TypeSig.commonCudfTypes +
+ TypeSig.DECIMAL.withPsNote(TypeEnum.DECIMAL, "Only supported for Parquet")).nested(),
+ TypeSig.all),
(p, conf, parent, r) => new SparkPlanMeta[DataWritingCommandExec](p, conf, parent, r) {
override val childDataWriteCmds: scala.Seq[DataWritingCommandMeta[_]] =
Seq(GpuOverrides.wrapDataWriteCmds(p.cmd, conf, Some(this)))
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
index 98ba2a92bbf..5668ac3bcd7 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -32,7 +32,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
import org.apache.spark.sql.rapids.ColumnarWriteTaskStatsTracker
import org.apache.spark.sql.rapids.execution.TrampolineUtil
-import org.apache.spark.sql.types.{DataTypes, DateType, StructType, TimestampType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DateType, DecimalType, MapType, StructType, TimestampType}
import org.apache.spark.sql.vectorized.ColumnarBatch
object GpuParquetFileFormat {
@@ -41,6 +41,13 @@ object GpuParquetFileFormat {
spark: SparkSession,
options: Map[String, String],
schema: StructType): Option[GpuParquetFileFormat] = {
+
+ val unSupportedTypes =
+ schema.filterNot(field => GpuOverrides.isSupportedType(field.dataType, allowDecimal = true))
+ if (unSupportedTypes.nonEmpty) {
+ meta.willNotWorkOnGpu(s"These types aren't supported for parquet $unSupportedTypes")
+ }
+
val sqlConf = spark.sessionState.conf
val parquetOptions = new ParquetOptions(options, sqlConf)
@@ -247,31 +254,45 @@ class GpuParquetWriter(
*/
override def write(batch: ColumnarBatch,
statsTrackers: Seq[ColumnarWriteTaskStatsTracker]): Unit = {
+ val outputMillis = outputTimestampType == ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString
val newBatch =
- if (outputTimestampType == ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) {
- new ColumnarBatch(GpuColumnVector.extractColumns(batch).map {
- cv => {
- cv.dataType() match {
- case DataTypes.TimestampType => new GpuColumnVector(DataTypes.TimestampType,
- withResource(cv.getBase()) { v =>
- v.castTo(DType.TIMESTAMP_MILLISECONDS)
- })
- case _ => cv
- }
+ new ColumnarBatch(GpuColumnVector.extractColumns(batch).map {
+ cv => {
+ cv.dataType() match {
+ case DataTypes.TimestampType if outputMillis =>
+ new GpuColumnVector(DataTypes.TimestampType, withResource(cv.getBase()) { v =>
+ v.castTo(DType.TIMESTAMP_MILLISECONDS)
+ })
+ case d: DecimalType if d.precision < 10 =>
+ // There is a bug in Spark that causes a problem if we write Decimals with
+ // precision < 10 as Decimal64, so cast them to DECIMAL32 (scale negated) first.
+ // https://issues.apache.org/jira/browse/SPARK-34167
+ new GpuColumnVector(d, withResource(cv.getBase()) { v =>
+ v.castTo(DType.create(DType.DTypeEnum.DECIMAL32, -d.scale))
+ })
+ case _ => cv // all other columns are passed through unchanged
}
- })
- } else {
- batch
- }
+ }
+ })
super.write(newBatch, statsTrackers)
}
override val tableWriter: TableWriter = {
+ def precisionsList(t: DataType): Seq[Int] = { // decimal precisions in t, depth-first
+ t match {
+ case d: DecimalType => List(d.precision)
+ case s: StructType => s.flatMap(f => precisionsList(f.dataType)) // each field
+ case ArrayType(elementType, _) => precisionsList(elementType) // the element type
+ case _ => List.empty // non-decimal leaves and unhandled types (MapType included)
+ }
+ }
+
val writeContext = new ParquetWriteSupport().init(conf)
val builder = ParquetWriterOptions.builder()
.withMetadata(writeContext.getExtraMetaData)
.withCompressionType(compressionType)
.withTimestampInt96(outputTimestampType == ParquetOutputTimestampType.INT96)
+ .withPrecisionValues(dataSchema.flatMap(f => precisionsList(f.dataType)):_*)
dataSchema.foreach(entry => {
if (entry.nullable) {
builder.withColumnNames(entry.name)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 714a16779b4..76915acaa70 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -1313,7 +1313,7 @@ object SupportedOpsDocs {
println("S | ") // DATE
println("S | ") // TIMESTAMP
println("S | ") // STRING
- println("NS | ") // DECIMAL
+ println("S | ") // DECIMAL
println(" | ") // NULL
println("NS | ") // BINARY
println(" | ") // CALENDAR
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
index 80e11ff338d..8bc99822d77 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -44,7 +44,13 @@ object GpuOrcFileFormat extends Logging {
def tagGpuSupport(meta: RapidsMeta[_, _, _],
spark: SparkSession,
- options: Map[String, String]): Option[GpuOrcFileFormat] = {
+ options: Map[String, String],
+ schema: StructType): Option[GpuOrcFileFormat] = {
+
+ val unSupportedTypes = schema.filterNot(field => GpuOverrides.isSupportedType(field.dataType))
+ if (unSupportedTypes.nonEmpty) {
+ meta.willNotWorkOnGpu(s"These types aren't supported for orc $unSupportedTypes")
+ }
if (!meta.conf.isOrcEnabled) {
meta.willNotWorkOnGpu("ORC input and output has been disabled. To enable set" +