Check schema compatibility when building parquet readers #5434

Merged
13 commits, merged on May 12, 2022
70 changes: 68 additions & 2 deletions integration_tests/src/main/python/parquet_test.py
@@ -14,7 +14,8 @@

import pytest

from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture, assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql
from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture, assert_gpu_and_cpu_are_equal_collect, \
assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_py4j_exception
from data_gen import *
from marks import *
from pyspark.sql.types import *
@@ -820,6 +821,26 @@ def test_parquet_push_down_on_interval_type(spark_tmp_path):
"select * from testData where _c1 > interval '10 0:0:0' day to second")


def test_parquet_check_schema_compatibility(spark_tmp_path):
data_path = spark_tmp_path + '/PARQUET_DATA'
gen_list = [('int', int_gen), ('long', long_gen), ('dec32', decimal_gen_32bit)]
with_cpu_session(lambda spark: gen_df(spark, gen_list).coalesce(1).write.parquet(data_path))

read_int_as_long = StructType(
[StructField('long', LongType()), StructField('int', LongType())])
assert_gpu_and_cpu_error(
lambda spark: spark.read.schema(read_int_as_long).parquet(data_path).collect(),
conf={},
error_message='Parquet column cannot be converted in')

read_dec32_as_dec64 = StructType(
[StructField('int', IntegerType()), StructField('dec32', DecimalType(15, 10))])
assert_gpu_and_cpu_error(
lambda spark: spark.read.schema(read_dec32_as_dec64).parquet(data_path).collect(),
conf={},
error_message='Parquet column cannot be converted in')


def test_parquet_read_case_insensitivity(spark_tmp_path):
gen_list = [('one', int_gen), ('tWo', byte_gen), ('THREE', boolean_gen)]
data_path = spark_tmp_path + '/PARQUET_DATA'
@@ -829,4 +850,49 @@ def test_parquet_read_case_insensitivity(spark_tmp_path):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.read.parquet(data_path).select('one', 'two', 'three'),
{'spark.sql.caseSensitive': 'false'}
)


# For nested types, the GPU throws an incompatibility exception with a different message than the CPU does.
def test_parquet_check_schema_compatibility_nested_types(spark_tmp_path):
data_path = spark_tmp_path + '/PARQUET_DATA'
gen_list = [('array_long', ArrayGen(long_gen)),
('array_array_int', ArrayGen(ArrayGen(int_gen))),
('struct_float', StructGen([('f', float_gen), ('d', double_gen)])),
('struct_array_int', StructGen([('a', ArrayGen(int_gen))])),
('map', map_string_string_gen[0])]
with_cpu_session(lambda spark: gen_df(spark, gen_list).coalesce(1).write.parquet(data_path))

read_array_long_as_int = StructType([StructField('array_long', ArrayType(IntegerType()))])
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_array_long_as_int).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')

read_arr_arr_int_as_long = StructType(
[StructField('array_array_int', ArrayType(ArrayType(LongType())))])
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_arr_arr_int_as_long).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')

read_struct_flt_as_dbl = StructType([StructField(
'struct_float', StructType([StructField('f', DoubleType())]))])
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_struct_flt_as_dbl).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')

read_struct_arr_int_as_long = StructType([StructField(
'struct_array_int', StructType([StructField('a', ArrayType(LongType()))]))])
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_struct_arr_int_as_long).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')

read_map_str_str_as_str_int = StructType([StructField(
'map', MapType(StringType(), IntegerType()))])
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_map_str_str_as_str_int).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')
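For reference, the failure mode asserted above can be reproduced from a plain Spark session. The snippet below is a minimal sketch, not code from the PR: it assumes a SparkSession named spark and a dataPath pointing at Parquet data whose int column was written as a 32-bit integer, as the first test writes it.

import org.apache.spark.sql.types.{LongType, StructField, StructType}

// Requesting bigint for a column stored as INT32 is not a supported conversion, so
// collect() fails with "Parquet column cannot be converted in <file> ...", the message
// the tests above assert on both the CPU and the GPU path.
val readIntAsLong = StructType(Seq(StructField("int", LongType)))
spark.read.schema(readIntAsLong).parquet(dataPath).collect()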
178 changes: 176 additions & 2 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -43,7 +43,7 @@ import org.apache.parquet.filter2.predicate.FilterApi
import org.apache.parquet.format.converter.ParquetMetadataConverter
import org.apache.parquet.hadoop.{ParquetFileReader, ParquetInputFormat}
import org.apache.parquet.hadoop.metadata._
import org.apache.parquet.schema.{GroupType, MessageType, OriginalType, PrimitiveType, Type, Types}
import org.apache.parquet.schema.{DecimalMetadata, GroupType, MessageType, OriginalType, PrimitiveType, Type, Types}
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName

import org.apache.spark.TaskContext
@@ -55,7 +55,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, PartitioningAwareFileIndex, SchemaColumnConvertNotSupportedException}
import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport
import org.apache.spark.sql.execution.datasources.v2.{FilePartitionReaderFactory, FileScan}
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -411,6 +411,9 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte
withResource(new NvtxRange("clipSchema", NvtxColor.DARK_GREEN)) { _ =>
val clippedSchemaTmp = ParquetReadSupport.clipParquetSchema(fileSchema, readDataSchema,
isCaseSensitive)
// Check if the read schema is compatible with the file schema.
checkSchemaCompat(clippedSchemaTmp, readDataSchema,
(t: Type, d: DataType) => throwTypeIncompatibleError(t, d, file.filePath))
// ParquetReadSupport.clipParquetSchema does most of what we want, but it includes
// everything in readDataSchema, even if it is not in fileSchema; we want to remove
// those entries for our own purposes
@@ -428,6 +431,177 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte
hasInt96Timestamps)
}
}

/**
* Recursively check if the read schema is compatible with the file schema. The errorCallback
* will be invoked to throw an exception as soon as any incompatible type pair is found.
*
* The function assumes that every element of the read schema is included in the file schema,
* so run this check only after clipping the read schema against the file schema.
*
* The function only accepts top-level schemas, that is, the structures of the root columns.
* Based on this assumption, it can infer the root types from the input schemas.
*
* @param fileType the input file's Parquet schema
* @param readType the Spark type to read from the Parquet file
* @param errorCallback callback invoked to throw an exception when a type mismatch is found
* @param rootFileType file type of each root column
* @param rootReadType read type of each root column
*/
private def checkSchemaCompat(fileType: Type,
readType: DataType,
errorCallback: (Type, DataType) => Unit,
rootFileType: Option[Type] = None,
rootReadType: Option[DataType] = None): Unit = {
readType match {
case struct: StructType =>
// case-insensitive field matching
val fileFieldMap = fileType.asGroupType().getFields.asScala
.map(f => f.getName.toLowerCase -> f).toMap
struct.fields.foreach { f =>
val curFile = fileFieldMap(f.name.toLowerCase)
checkSchemaCompat(curFile,
f.dataType,
errorCallback,
// Record the root type of each column so that a readable exception can be thrown
// for nested types.
Some(rootFileType.getOrElse(curFile)),
Some(rootReadType.getOrElse(f.dataType)))
}

case array: ArrayType =>
val fileChild = fileType.asGroupType().getType(0)
.asGroupType().getType(0)
checkSchemaCompat(fileChild, array.elementType, errorCallback,
rootFileType, rootReadType)

case map: MapType =>
val parquetMap = fileType.asGroupType().getType(0).asGroupType()
val parquetMapKey = parquetMap.getType(0)
val parquetMapValue = parquetMap.getType(1)
checkSchemaCompat(parquetMapKey, map.keyType, errorCallback,
rootFileType, rootReadType)
checkSchemaCompat(parquetMapValue, map.valueType, errorCallback,
rootFileType, rootReadType)

case dt =>
checkPrimitiveCompat(fileType.asPrimitiveType(),
dt,
() => errorCallback(rootFileType.get, rootReadType.get))
}
}
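// Illustrative walk-through (not part of the patch): for the nested-type tests above, a file
// column written as array<array<int>> but read as array<array<bigint>> recurses through
// list -> list -> element, fails the INT32-versus-LongType leaf check, and invokes the error
// callback with the root file type and root read type. The resulting message therefore names
// the whole column, e.g. "Column: array_array_int, Expected: array<array<bigint>>", rather
// than the anonymous inner element.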

/**
* Check the compatibility of primitive types. This function is modeled on the `getUpdater` method
* of [[org.apache.spark.sql.execution.datasources.parquet.ParquetVectorUpdaterFactory]].
*
* To avoid unnecessary pattern matching, this function is designed to return or throw ASAP.
*
* This function uses some deprecated Parquet APIs, because Spark 3.1 relies on an older
* version of parquet-mr.
*/
@scala.annotation.nowarn("msg=method getDecimalMetadata in class PrimitiveType is deprecated")
private def checkPrimitiveCompat(pt: PrimitiveType,
dt: DataType,
errorCallback: () => Unit): Unit = {
pt.getPrimitiveTypeName match {
case PrimitiveTypeName.BOOLEAN if dt == DataTypes.BooleanType =>
case PrimitiveTypeName.BOOLEAN =>
errorCallback()

// TODO: After we deprecate Spark 3.1, add YearMonthIntervalType
case PrimitiveTypeName.INT32 if dt == DataTypes.IntegerType =>
case PrimitiveTypeName.INT32 if dt == DataTypes.ByteType =>
case PrimitiveTypeName.INT32 if dt == DataTypes.ShortType =>
case PrimitiveTypeName.INT32 if dt == DataTypes.DateType =>
// TODO: After we deprecate Spark 3.1, replace OriginalType with LogicalTypeAnnotation
case PrimitiveTypeName.INT32 if dt == DataTypes.LongType &&
pt.getOriginalType == OriginalType.UINT_32 =>
// TODO: After we deprecate Spark 3.1, fetch meta data from LogicalTypeAnnotation
case PrimitiveTypeName.INT32 if DecimalType.is32BitDecimalType(dt) &&
isDecimalTypeMatched(pt.getDecimalMetadata, dt) =>
case PrimitiveTypeName.INT32 =>
errorCallback()

// TODO: After we deprecate Spark 3.1, add DayTimeIntervalType
case PrimitiveTypeName.INT64 if dt == DataTypes.LongType =>
// TODO: After we deprecate Spark 3.1, replace OriginalType with LogicalTypeAnnotation
case PrimitiveTypeName.INT64 if pt.getOriginalType == OriginalType.UINT_64 &&
DecimalType.is64BitDecimalType(dt) && {
val decType = dt.asInstanceOf[DecimalType]
decType.precision == 20 && decType.scale == 0
} =>
case PrimitiveTypeName.INT64 if pt.getOriginalType == OriginalType.TIMESTAMP_MICROS ||
pt.getOriginalType == OriginalType.TIMESTAMP_MILLIS =>
// TODO: After we deprecate Spark 3.1, fetch meta data from LogicalTypeAnnotation
case PrimitiveTypeName.INT64 if DecimalType.is64BitDecimalType(dt) &&
isDecimalTypeMatched(pt.getDecimalMetadata, dt) =>
case PrimitiveTypeName.INT64 =>
errorCallback()

case PrimitiveTypeName.FLOAT if dt == DataTypes.FloatType =>
case PrimitiveTypeName.FLOAT =>
errorCallback()

case PrimitiveTypeName.DOUBLE if dt == DataTypes.DoubleType =>
case PrimitiveTypeName.DOUBLE =>
errorCallback()

case PrimitiveTypeName.INT96 if dt == DataTypes.TimestampType =>
case PrimitiveTypeName.INT96 =>
errorCallback()

case PrimitiveTypeName.BINARY if dt == DataTypes.StringType =>
case PrimitiveTypeName.BINARY if dt == DataTypes.BinaryType =>
// TODO: After we deprecate Spark 3.1, fetch meta data from LogicalTypeAnnotation
case PrimitiveTypeName.BINARY if DecimalType.isByteArrayDecimalType(dt) &&
isDecimalTypeMatched(pt.getDecimalMetadata, dt) =>
case PrimitiveTypeName.BINARY =>
errorCallback()

// TODO: After we deprecate Spark 3.1, fetch meta data from LogicalTypeAnnotation
case PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY if DecimalType.is32BitDecimalType(dt) &&
isDecimalTypeMatched(pt.getDecimalMetadata, dt) =>
case PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY if DecimalType.is64BitDecimalType(dt) &&
isDecimalTypeMatched(pt.getDecimalMetadata, dt) =>
case PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY if DecimalType.isByteArrayDecimalType(dt) &&
isDecimalTypeMatched(pt.getDecimalMetadata, dt) =>
case PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY =>
errorCallback()

// If we get here, the combination of Spark and Parquet types is invalid or not
// supported.
case _ =>
errorCallback()
}
}

private def throwTypeIncompatibleError(parquetType: Type,
sparkType: DataType,
filePath: String): Unit = {
val exception = new SchemaColumnConvertNotSupportedException(
parquetType.getName,
parquetType.toString,
sparkType.catalogString)

// A copy of QueryExecutionErrors.unsupportedSchemaColumnConvertError introduced in 3.2+
// TODO: replace with unsupportedSchemaColumnConvertError after we deprecate Spark 3.1
val message = "Parquet column cannot be converted in " +
s"file $filePath. Column: ${parquetType.getName}, " +
s"Expected: ${sparkType.catalogString}, Found: $parquetType"
throw new QueryExecutionException(message, exception)
}
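// For the read-int-as-long case in the new Python test, the resulting message looks roughly
// like the following (the file path and the exact rendering of the Parquet type are
// illustrative):
//   Parquet column cannot be converted in file /tmp/.../part-00000.parquet.
//   Column: int, Expected: bigint, Found: optional int32 int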

@scala.annotation.nowarn("msg=class DecimalMetadata in package schema is deprecated")
private def isDecimalTypeMatched(metadata: DecimalMetadata,
sparkType: DataType): Boolean = {
if (metadata == null) {
false
} else {
val dt = sparkType.asInstanceOf[DecimalType]
metadata.getPrecision <= dt.precision && metadata.getScale == dt.scale
}
}
}
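The decimal rule above is easy to misread, so the following standalone sketch (illustrative only, with made-up precision and scale values) mirrors the precision/scale part of isDecimalTypeMatched: the file's declared precision may be less than or equal to the requested precision, but the scale must match exactly. Note that checkPrimitiveCompat additionally requires the requested decimal to stay in the same physical bucket (INT32, INT64, BINARY, or FIXED_LEN_BYTE_ARRAY); the dec32-as-DecimalType(15, 10) test above fails that path because DecimalType(15, 10) is no longer a 32-bit decimal.

import org.apache.spark.sql.types.DecimalType

object DecimalCompatSketch {
  // Mirrors the precision/scale comparison in isDecimalTypeMatched.
  def decimalMatched(filePrecision: Int, fileScale: Int, read: DecimalType): Boolean =
    filePrecision <= read.precision && fileScale == read.scale

  def main(args: Array[String]): Unit = {
    println(decimalMatched(7, 2, DecimalType(9, 2))) // true: precision may widen, scale matches
    println(decimalMatched(7, 2, DecimalType(9, 3))) // false: the scale must match exactly
  }
}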

/**