Skip to content

Commit

Permalink
Decimal Support for writing Parquet (NVIDIA#1531)
Browse files Browse the repository at this point in the history
* Decimal Support for writing Parquet

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review comments

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* updated static doc

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* generated supported_ops

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review comments

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

Co-authored-by: Raza Jafri <rjafri@nvidia.com>
  • Loading branch information
razajafri and razajafri authored Jan 23, 2021
1 parent 625ef44 commit 52cdef7
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 29 deletions.
4 changes: 2 additions & 2 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ Accelerator supports are described below.
<td>S</td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td><em>PS* (Only supported for Parquet)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -17821,7 +17821,7 @@ dates or timestamps, or for a lack of type coercion support.
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td>S</td>
<td></td>
<td><b>NS</b></td>
<td></td>
Expand Down
1 change: 0 additions & 1 deletion integration_tests/src/main/python/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ def __init__(self, precision=None, scale=None, nullable=True, special_cases=[]):
scale = 0
DECIMAL_MIN = Decimal('-' + ('9' * precision) + 'e' + str(-scale))
DECIMAL_MAX = Decimal(('9'* precision) + 'e' + str(-scale))
special_cases = [Decimal('0'), Decimal(DECIMAL_MIN), Decimal(DECIMAL_MAX)]
super().__init__(DecimalType(precision, scale), nullable=nullable, special_cases=special_cases)
self._scale = scale
self._precision = precision
Expand Down
14 changes: 10 additions & 4 deletions integration_tests/src/main/python/parquet_write_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,7 @@
from marks import *
from pyspark.sql.types import *
from spark_session import with_cpu_session, with_gpu_session, is_before_spark_310
import random

# test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for
# non-cloud
Expand All @@ -33,16 +34,21 @@
'spark.sql.legacy.parquet.int96RebaseModeInWrite': 'CORRECTED'}

parquet_write_gens_list = [
[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen]]
[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen],
pytest.param([byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen, decimal_gen_default,
decimal_gen_scale_precision, decimal_gen_same_scale_precision, decimal_gen_64bit],
marks=pytest.mark.allow_non_gpu("CoalesceExec"))]

parquet_ts_write_options = ['INT96', 'TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS']

@pytest.mark.parametrize('parquet_gens', parquet_write_gens_list, ids=idfn)
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
@pytest.mark.parametrize('ts_type', parquet_ts_write_options)
def test_write_round_trip(spark_tmp_path, parquet_gens, v1_enabled_list, ts_type, reader_confs):
def test_write_round_trip(spark_tmp_path, parquet_gens, v1_enabled_list, ts_type,
reader_confs):
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
data_path = spark_tmp_path + '/PARQUET_DATA'
all_confs = reader_confs.copy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ final class InsertIntoHadoopFsRelationCommandMeta(
willNotWorkOnGpu("JSON output is not supported")
None
case f if GpuOrcFileFormat.isSparkOrcFormat(f) =>
GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.options)
GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.options, cmd.query.schema)
case _: ParquetFileFormat =>
GpuParquetFileFormat.tagGpuSupport(this, spark, cmd.options, cmd.query.schema)
case _: TextFileFormat =>
Expand Down Expand Up @@ -371,7 +371,7 @@ final class CreateDataSourceTableAsSelectCommandMeta(
// If that changes then this will start failing because we don't have a mapping.
gpuProvider = origProvider.getConstructor().newInstance() match {
case f: FileFormat if GpuOrcFileFormat.isSparkOrcFormat(f) =>
GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.table.storage.properties)
GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.table.storage.properties, cmd.query.schema)
case _: ParquetFileFormat =>
GpuParquetFileFormat.tagGpuSupport(this, spark,
cmd.table.storage.properties, cmd.query.schema)
Expand Down Expand Up @@ -2344,7 +2344,9 @@ object GpuOverrides {
}),
exec[DataWritingCommandExec](
"Writing data",
ExecChecks(TypeSig.commonCudfTypes, TypeSig.all),
ExecChecks((TypeSig.commonCudfTypes +
TypeSig.DECIMAL.withPsNote(TypeEnum.DECIMAL, "Only supported for Parquet")).nested(),
TypeSig.all),
(p, conf, parent, r) => new SparkPlanMeta[DataWritingCommandExec](p, conf, parent, r) {
override val childDataWriteCmds: scala.Seq[DataWritingCommandMeta[_]] =
Seq(GpuOverrides.wrapDataWriteCmds(p.cmd, conf, Some(this)))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
import org.apache.spark.sql.rapids.ColumnarWriteTaskStatsTracker
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.types.{DataTypes, DateType, StructType, TimestampType}
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DateType, DecimalType, MapType, StructType, TimestampType}
import org.apache.spark.sql.vectorized.ColumnarBatch

object GpuParquetFileFormat {
Expand All @@ -41,6 +41,13 @@ object GpuParquetFileFormat {
spark: SparkSession,
options: Map[String, String],
schema: StructType): Option[GpuParquetFileFormat] = {

val unSupportedTypes =
schema.filterNot(field => GpuOverrides.isSupportedType(field.dataType, allowDecimal = true))
if (unSupportedTypes.nonEmpty) {
meta.willNotWorkOnGpu(s"These types aren't supported for parquet $unSupportedTypes")
}

val sqlConf = spark.sessionState.conf
val parquetOptions = new ParquetOptions(options, sqlConf)

Expand Down Expand Up @@ -247,31 +254,45 @@ class GpuParquetWriter(
*/
override def write(batch: ColumnarBatch,
statsTrackers: Seq[ColumnarWriteTaskStatsTracker]): Unit = {
val outputMillis = outputTimestampType == ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString
val newBatch =
if (outputTimestampType == ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) {
new ColumnarBatch(GpuColumnVector.extractColumns(batch).map {
cv => {
cv.dataType() match {
case DataTypes.TimestampType => new GpuColumnVector(DataTypes.TimestampType,
withResource(cv.getBase()) { v =>
v.castTo(DType.TIMESTAMP_MILLISECONDS)
})
case _ => cv
}
new ColumnarBatch(GpuColumnVector.extractColumns(batch).map {
cv => {
cv.dataType() match {
case DataTypes.TimestampType if outputMillis =>
new GpuColumnVector(DataTypes.TimestampType, withResource(cv.getBase()) { v =>
v.castTo(DType.TIMESTAMP_MILLISECONDS)
})
case d: DecimalType if d.precision < 10 =>
// Work around SPARK-34167: Spark hits a problem reading decimals with precision < 10
// when they have been written as Decimal64, so cast them down to Decimal32 before writing.
// https://issues.apache.org/jira/browse/SPARK-34167
new GpuColumnVector(d, withResource(cv.getBase()) { v =>
v.castTo(DType.create(DType.DTypeEnum.DECIMAL32, -d.scale))
})
case _ => cv
}
})
} else {
batch
}
}
})

super.write(newBatch, statsTrackers)
}
override val tableWriter: TableWriter = {
// Gathers the precision of every DecimalType found in `t`, descending into
// struct fields and array element types. Any other type contributes nothing.
// Precisions are returned in the order the decimal types are encountered.
def precisionsList(t: DataType): Seq[Int] = t match {
  case ArrayType(elementType, _) =>
    // An array holds a single element type; recurse into it.
    precisionsList(elementType)
  case struct: StructType =>
    // Concatenate the results of each field, preserving field order.
    struct.flatMap(field => precisionsList(field.dataType))
  case decimal: DecimalType =>
    List(decimal.precision)
  case _ =>
    List.empty
}

val writeContext = new ParquetWriteSupport().init(conf)
val builder = ParquetWriterOptions.builder()
.withMetadata(writeContext.getExtraMetaData)
.withCompressionType(compressionType)
.withTimestampInt96(outputTimestampType == ParquetOutputTimestampType.INT96)
.withPrecisionValues(dataSchema.flatMap(f => precisionsList(f.dataType)):_*)
dataSchema.foreach(entry => {
if (entry.nullable) {
builder.withColumnNames(entry.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1313,7 +1313,7 @@ object SupportedOpsDocs {
println("<td>S</td>") // DATE
println("<td>S</td>") // TIMESTAMP
println("<td>S</td>") // STRING
println("<td><b>NS</b></td>") // DECIMAL
println("<td>S</td>") // DECIMAL
println("<td></td>") // NULL
println("<td><b>NS</b></td>") // BINARY
println("<td></td>") // CALENDAR
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,7 +44,13 @@ object GpuOrcFileFormat extends Logging {

def tagGpuSupport(meta: RapidsMeta[_, _, _],
spark: SparkSession,
options: Map[String, String]): Option[GpuOrcFileFormat] = {
options: Map[String, String],
schema: StructType): Option[GpuOrcFileFormat] = {

val unSupportedTypes = schema.filterNot(field => GpuOverrides.isSupportedType(field.dataType))
if (unSupportedTypes.nonEmpty) {
meta.willNotWorkOnGpu(s"These types aren't supported for orc $unSupportedTypes")
}

if (!meta.conf.isOrcEnabled) {
meta.willNotWorkOnGpu("ORC input and output has been disabled. To enable set" +
Expand Down

0 comments on commit 52cdef7

Please sign in to comment.