Check schema compatibility when building parquet readers #5434

Merged: 13 commits, May 12, 2022
70 changes: 68 additions & 2 deletions integration_tests/src/main/python/parquet_test.py
@@ -14,7 +14,8 @@

import pytest

from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture, assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql
from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture, assert_gpu_and_cpu_are_equal_collect, \
assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_py4j_exception
from data_gen import *
from marks import *
from pyspark.sql.types import *
@@ -829,4 +830,69 @@ def test_parquet_read_case_insensitivity(spark_tmp_path):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.read.parquet(data_path).select('one', 'two', 'three'),
{'spark.sql.caseSensitive': 'false'}
)
)


def test_parquet_check_schema_compatibility(spark_tmp_path):
    data_path = spark_tmp_path + '/PARQUET_DATA'
    gen_list = [('int', int_gen), ('long', long_gen), ('dec32', decimal_gen_32bit)]
    with_cpu_session(lambda spark: gen_df(spark, gen_list).coalesce(1).write.parquet(data_path))

    read_int_as_long = StructType(
        [StructField('long', LongType()), StructField('int', LongType())])
    assert_gpu_and_cpu_error(
        lambda spark: spark.read.schema(read_int_as_long).parquet(data_path).collect(),
        conf={},
        error_message='Parquet column cannot be converted in')

    read_dec32_as_dec64 = StructType(
        [StructField('int', IntegerType()), StructField('dec32', DecimalType(15, 10))])
    assert_gpu_and_cpu_error(
        lambda spark: spark.read.schema(read_dec32_as_dec64).parquet(data_path).collect(),
        conf={},
        error_message='Parquet column cannot be converted in')


# For nested types, the GPU throws an incompatibility exception with a different message than the CPU.
def test_parquet_check_schema_compatibility_nested_types(spark_tmp_path):
    data_path = spark_tmp_path + '/PARQUET_DATA'
    gen_list = [('array_long', ArrayGen(long_gen)),
                ('array_array_int', ArrayGen(ArrayGen(int_gen))),
                ('struct_float', StructGen([('f', float_gen), ('d', double_gen)])),
                ('struct_array_int', StructGen([('a', ArrayGen(int_gen))])),
                ('map', map_string_string_gen[0])]
    with_cpu_session(lambda spark: gen_df(spark, gen_list).coalesce(1).write.parquet(data_path))

    read_array_long_as_int = StructType([StructField('array_long', ArrayType(IntegerType()))])
    assert_py4j_exception(
        lambda: with_gpu_session(
            lambda spark: spark.read.schema(read_array_long_as_int).parquet(data_path).collect()),
        error_message='Parquet column cannot be converted in')

    read_arr_arr_int_as_long = StructType(
        [StructField('array_array_int', ArrayType(ArrayType(LongType())))])
    assert_py4j_exception(
        lambda: with_gpu_session(
            lambda spark: spark.read.schema(read_arr_arr_int_as_long).parquet(data_path).collect()),
        error_message='Parquet column cannot be converted in')

    read_struct_flt_as_dbl = StructType([StructField(
        'struct_float', StructType([StructField('f', DoubleType())]))])
    assert_py4j_exception(
        lambda: with_gpu_session(
            lambda spark: spark.read.schema(read_struct_flt_as_dbl).parquet(data_path).collect()),
        error_message='Parquet column cannot be converted in')

    read_struct_arr_int_as_long = StructType([StructField(
        'struct_array_int', StructType([StructField('a', ArrayType(LongType()))]))])
    assert_py4j_exception(
        lambda: with_gpu_session(
            lambda spark: spark.read.schema(read_struct_arr_int_as_long).parquet(data_path).collect()),
        error_message='Parquet column cannot be converted in')

    read_map_str_str_as_str_int = StructType([StructField(
        'map', MapType(StringType(), IntegerType()))])
    assert_py4j_exception(
        lambda: with_gpu_session(
            lambda spark: spark.read.schema(read_map_str_str_as_str_int).parquet(data_path).collect()),
        error_message='Parquet column cannot be converted in')
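
Note (illustrative, not part of the diff): the simplest mismatch exercised above, reading an INT32 column back as LongType, can also be reproduced from a plain Spark application. The sketch below is Scala rather than the test harness's Python; the local master, application name, object name, and /tmp path are assumptions, and the expected failure message is the one the tests assert on.

// Minimal sketch, assuming a local Spark session and a writable /tmp path.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{LongType, StructField, StructType}

object ReadIntAsLongRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("read-int-as-long").getOrCreate()
    val path = "/tmp/PARQUET_DATA"
    // Write a single INT32 column named "int".
    spark.range(10).selectExpr("cast(id as int) as int").write.mode("overwrite").parquet(path)
    // Requesting the same column back as LongType should fail on the CPU, and, with this
    // change, on the GPU as well, with "Parquet column cannot be converted in file ...".
    val readIntAsLong = StructType(Seq(StructField("int", LongType)))
    spark.read.schema(readIntAsLong).parquet(path).collect()
    spark.stop()
  }
}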
194 changes: 192 additions & 2 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -43,7 +43,7 @@ import org.apache.parquet.filter2.predicate.FilterApi
import org.apache.parquet.format.converter.ParquetMetadataConverter
import org.apache.parquet.hadoop.{ParquetFileReader, ParquetInputFormat}
import org.apache.parquet.hadoop.metadata._
import org.apache.parquet.schema.{GroupType, MessageType, OriginalType, PrimitiveType, Type, Types}
import org.apache.parquet.schema.{DecimalMetadata, GroupType, MessageType, OriginalType, PrimitiveType, Type, Types}
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName

import org.apache.spark.TaskContext
@@ -55,7 +55,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, PartitioningAwareFileIndex, SchemaColumnConvertNotSupportedException}
import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport
import org.apache.spark.sql.execution.datasources.v2.{FilePartitionReaderFactory, FileScan}
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -411,6 +411,10 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte
withResource(new NvtxRange("clipSchema", NvtxColor.DARK_GREEN)) { _ =>
val clippedSchemaTmp = ParquetReadSupport.clipParquetSchema(fileSchema, readDataSchema,
isCaseSensitive)
// Check if the read schema is compatible with the file schema.
checkSchemaCompat(clippedSchemaTmp, readDataSchema,
(t: Type, d: DataType) => throwTypeIncompatibleError(t, d, file.filePath),
isCaseSensitive)
// ParquetReadSupport.clipParquetSchema does most of what we want, but it includes
// everything in readDataSchema, even if it is not in fileSchema; we want to remove those
// for our own purposes
@@ -428,6 +432,192 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte
hasInt96Timestamps)
}
}

/**
* Recursively check if the read schema is compatible with the file schema. The errorCallback
* will be invoked to throw an exception once any incompatible type pair is found.
*
* The function assumes all elements in the read schema are included in the file schema, so it
* should be run after clipping the read schema against the file schema.
*
* The function only accepts top-level schemas, i.e. the structs of the root columns. Based on
* this assumption, it can infer the root type of each column from the input schemas.
*
* @param fileType the Parquet type of the input file
* @param readType the Spark type used to read the Parquet file
* @param errorCallback callback invoked to throw an exception when a type mismatch is found
* @param isCaseSensitive whether field-name matching is case sensitive
* @param rootFileType the Parquet type of the enclosing root column
* @param rootReadType the Spark read type of the enclosing root column
*/
private def checkSchemaCompat(fileType: Type,
readType: DataType,
errorCallback: (Type, DataType) => Unit,
isCaseSensitive: Boolean,
rootFileType: Option[Type] = None,
rootReadType: Option[DataType] = None): Unit = {
readType match {
case struct: StructType =>
val fileFieldMap = fileType.asGroupType().getFields.asScala
.map { f =>
(if (isCaseSensitive) f.getName else f.getName.toLowerCase(Locale.ROOT)) -> f
}.toMap
struct.fields.foreach { f =>
val curFile = fileFieldMap(
if (isCaseSensitive) f.name else f.name.toLowerCase(Locale.ROOT))
checkSchemaCompat(curFile,
f.dataType,
errorCallback,
isCaseSensitive,
// Record root types for each column, so as to throw a readable exception
// over nested types.
Some(rootFileType.getOrElse(curFile)),
Some(rootReadType.getOrElse(f.dataType)))
}

case array: ArrayType =>
val fileChild = fileType.asGroupType().getType(0)
.asGroupType().getType(0)
checkSchemaCompat(fileChild, array.elementType, errorCallback, isCaseSensitive,
rootFileType, rootReadType)

case map: MapType =>
val parquetMap = fileType.asGroupType().getType(0).asGroupType()
val parquetMapKey = parquetMap.getType(0)
val parquetMapValue = parquetMap.getType(1)
checkSchemaCompat(parquetMapKey, map.keyType, errorCallback, isCaseSensitive,
rootFileType, rootReadType)
checkSchemaCompat(parquetMapValue, map.valueType, errorCallback, isCaseSensitive,
rootFileType, rootReadType)

case dt =>
checkPrimitiveCompat(fileType.asPrimitiveType(),
dt,
() => errorCallback(rootFileType.get, rootReadType.get))
}
}
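
For context on the getType(0) chains above (illustrative, not part of the diff): Spark writes arrays with Parquet's standard three-level LIST encoding, so the element type sits two groups below the column, and maps use a single repeated key/value group. The sketch below parses such a schema with parquet-mr; the object name and schema string are assumptions chosen to mirror the array_long column from the Python test.

// Minimal sketch, assuming parquet-mr is on the classpath (it is for the plugin).
import org.apache.parquet.schema.MessageTypeParser

object ListEncodingSketch {
  def main(args: Array[String]): Unit = {
    val schema = MessageTypeParser.parseMessageType(
      """message spark_schema {
        |  optional group array_long (LIST) {
        |    repeated group list {
        |      optional int64 element;
        |    }
        |  }
        |}""".stripMargin)
    val listField = schema.getType("array_long").asGroupType()
    // First getType(0): the repeated "list" wrapper group.
    // Second getType(0): the element type itself, which checkSchemaCompat compares
    // against ArrayType.elementType.
    val element = listField.getType(0).asGroupType().getType(0).asPrimitiveType()
    println(element) // prints the int64 element field
  }
}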

/**
* Check the compatibility over primitive types. This function is modeled on the `getUpdater`
* method of [[org.apache.spark.sql.execution.datasources.parquet.ParquetVectorUpdaterFactory]].
*
* To avoid unnecessary pattern matching, this function is designed to return or throw ASAP.
*
* This function uses some deprecated Parquet APIs, because Spark 3.1 relies on an older
* version of parquet-mr.
*/
@scala.annotation.nowarn("msg=method getDecimalMetadata in class PrimitiveType is deprecated")
private def checkPrimitiveCompat(pt: PrimitiveType,
dt: DataType,
errorCallback: () => Unit): Unit = {
pt.getPrimitiveTypeName match {
case PrimitiveTypeName.BOOLEAN if dt == DataTypes.BooleanType =>
return

// TODO: add YearMonthIntervalType
case PrimitiveTypeName.INT32 =>
if (dt == DataTypes.IntegerType || canReadAsIntDecimal(pt, dt)) {
return
}
// TODO: After we deprecate Spark 3.1, replace OriginalType with LogicalTypeAnnotation
if (dt == DataTypes.LongType && pt.getOriginalType == OriginalType.UINT_32) {
return
}
// TODO: Add the converters below for INT32. The converters work when evolving the schema over
// the cuDF table read from the Parquet file. https://github.com/NVIDIA/spark-rapids/issues/5445
if (dt == DataTypes.ByteType || dt == DataTypes.ShortType || dt == DataTypes.DateType) {
Review comment from @sperlingxx (Collaborator, Author), May 11, 2022: I added downcast converters
for INT_32 in this PR to close #5445, since UT cases of parquet writing fail if these combinations
are disabled.
return
}

// TODO: add DayTimeIntervalType
case PrimitiveTypeName.INT64 =>
if (dt == DataTypes.LongType || canReadAsLongDecimal(pt, dt)) {
return
}
// TODO: After we deprecate Spark 3.1, replace OriginalType with LogicalTypeAnnotation
if (isLongDecimal(dt) && pt.getOriginalType == OriginalType.UINT_64) {
return
}
if (pt.getOriginalType == OriginalType.TIMESTAMP_MICROS ||
pt.getOriginalType == OriginalType.TIMESTAMP_MILLIS) {
return
}

case PrimitiveTypeName.FLOAT if dt == DataTypes.FloatType =>
return

case PrimitiveTypeName.DOUBLE if dt == DataTypes.DoubleType =>
return

case PrimitiveTypeName.INT96 if dt == DataTypes.TimestampType =>
return

case PrimitiveTypeName.BINARY if dt == DataTypes.StringType ||
dt == DataTypes.BinaryType || canReadAsBinaryDecimal(pt, dt) =>
return

case PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY if canReadAsIntDecimal(pt, dt) ||
canReadAsLongDecimal(pt, dt) || canReadAsBinaryDecimal(pt, dt) =>
return

case _ =>
}

// If we get here, it means the combination of Spark and Parquet type is invalid or not
// supported.
errorCallback()
}

private def throwTypeIncompatibleError(parquetType: Type,
sparkType: DataType,
filePath: String): Unit = {
val exception = new SchemaColumnConvertNotSupportedException(
parquetType.getName,
parquetType.toString,
sparkType.catalogString)

// A copy of QueryExecutionErrors.unsupportedSchemaColumnConvertError introduced in 3.2+
// TODO: replace with unsupportedSchemaColumnConvertError after we deprecate Spark 3.1
val message = "Parquet column cannot be converted in " +
s"file $filePath. Column: ${parquetType.getName}, " +
s"Expected: ${sparkType.catalogString}, Found: $parquetType"
throw new QueryExecutionException(message, exception)
}

private def isLongDecimal(dt: DataType): Boolean =
dt match {
case d: DecimalType => d.precision == 20 && d.scale == 0
case _ => false
}

// TODO: After we deprecate Spark 3.1, fetch decimal meta with DecimalLogicalTypeAnnotation
@scala.annotation.nowarn("msg=method getDecimalMetadata in class PrimitiveType is deprecated")
private def canReadAsIntDecimal(pt: PrimitiveType, dt: DataType) = {
DecimalType.is32BitDecimalType(dt) && isDecimalTypeMatched(pt.getDecimalMetadata, dt)
}

// TODO: After we deprecate Spark 3.1, fetch decimal meta with DecimalLogicalTypeAnnotation
@scala.annotation.nowarn("msg=method getDecimalMetadata in class PrimitiveType is deprecated")
private def canReadAsLongDecimal(pt: PrimitiveType, dt: DataType): Boolean = {
DecimalType.is64BitDecimalType(dt) && isDecimalTypeMatched(pt.getDecimalMetadata, dt)
}

// TODO: After we deprecate Spark 3.1, fetch decimal meta with DecimalLogicalTypeAnnotation
@scala.annotation.nowarn("msg=method getDecimalMetadata in class PrimitiveType is deprecated")
private def canReadAsBinaryDecimal(pt: PrimitiveType, dt: DataType): Boolean = {
DecimalType.isByteArrayDecimalType(dt) && isDecimalTypeMatched(pt.getDecimalMetadata, dt)
}

// TODO: After we deprecate Spark 3.1, fetch decimal meta with DecimalLogicalTypeAnnotation
@scala.annotation.nowarn("msg=class DecimalMetadata in package schema is deprecated")
private def isDecimalTypeMatched(metadata: DecimalMetadata,
sparkType: DataType): Boolean = {
if (metadata == null) {
false
} else {
val dt = sparkType.asInstanceOf[DecimalType]
metadata.getPrecision <= dt.precision && metadata.getScale == dt.scale
}
}
}
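
A short reading aid for the decimal helpers above (illustrative, not part of the diff): isDecimalTypeMatched accepts a wider read precision but requires an identical scale, and the caller additionally requires the read type to stay in the same physical width class as the file column. The object name and the precision/scale values below are assumptions chosen to mirror the dec32 test column.

// Minimal sketch of the precision/scale rule, with the file's DecimalMetadata
// stood in by plain ints.
import org.apache.spark.sql.types.DecimalType

object DecimalMatchSketch {
  def decimalMatched(filePrecision: Int, fileScale: Int, readType: DecimalType): Boolean =
    filePrecision <= readType.precision && fileScale == readType.scale

  def main(args: Array[String]): Unit = {
    // Suppose the dec32 column was written as decimal(7, 2).
    println(decimalMatched(7, 2, DecimalType(9, 2)))   // true: wider precision, same scale
    println(decimalMatched(7, 2, DecimalType(15, 10))) // false: the scale differs
    // Even when this predicate holds, canReadAsIntDecimal would still reject a read
    // type like decimal(15, 2), because a 64-bit read decimal does not match an
    // INT32-backed file column.
  }
}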

/**