Re-enable the struct support for the orc reader.
Also add tests for nested predicate pushdown, and add support for
nested column pruning.

Relevant PRs:
  NVIDIA#3079
  NVIDIA#2887

Signed-off-by: Firestarman <firestarmanllc@gmail.com>
firestarman committed Aug 11, 2021
1 parent 92e0ecf commit 8761bbc
Showing 6 changed files with 79 additions and 13 deletions.
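As context for the diffs that follow, here is a minimal end-to-end sketch of the scenario this commit covers: reading ORC data that contains struct columns, with a comparison on a nested field that is a candidate for pushdown. It is not part of the commit; the app name, local master, output path, and column names are illustrative, and the plugin config assumes the rapids-4-spark jar is on the classpath (without it the same query simply runs on the CPU).

import org.apache.spark.sql.SparkSession

object OrcStructReadSketch {
  case class Inner(sa: Long, sb: String)
  case class Record(a: Int, s: Inner)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("orc-struct-read-sketch")
      .master("local[*]")
      // Assumed setup: load the RAPIDS Accelerator if its jar is on the classpath.
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
      .getOrCreate()
    import spark.implicits._

    val path = "/tmp/orc_struct_example" // hypothetical location

    // Write a small dataset with a struct column `s` holding fields `sa` and `sb`.
    Seq(Record(1, Inner(10L, "x")), Record(2, Inner(20L, "y")))
      .toDS().write.mode("overwrite").orc(path)

    // Read it back with a comparison on a nested field; this is the kind of
    // nested predicate pushdown the new ORC test exercises.
    spark.read.orc(path).where($"s.sa" >= 15L).show()

    spark.stop()
  }
}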
4 changes: 2 additions & 2 deletions docs/supported_ops.md
@@ -21281,9 +21281,9 @@ dates or timestamps, or for a lack of type coercion support.
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for nested TIMESTAMP;<br/>missing nested DECIMAL, BINARY, MAP, STRUCT, UDT</em></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for nested TIMESTAMP;<br/>missing nested DECIMAL, BINARY, MAP, UDT</em></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for nested TIMESTAMP;<br/>missing nested DECIMAL, BINARY, MAP, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
43 changes: 40 additions & 3 deletions integration_tests/src/main/python/orc_test.py
@@ -20,6 +20,7 @@
from marks import *
from pyspark.sql.types import *
from spark_session import with_cpu_session, with_spark_session
from parquet_test import _nested_pruning_schemas

def read_orc_df(data_path):
return lambda spark : spark.read.orc(data_path)
@@ -50,13 +51,28 @@ def test_basic_read(std_input_path, name, read_func, v1_enabled_list, orc_impl,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))]

orc_basic_struct_gen = StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(orc_basic_gens)])

# Some array gens, but not all because of nesting
orc_array_gens_sample = [ArrayGen(sub_gen) for sub_gen in orc_basic_gens] + [
ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10),
ArrayGen(ArrayGen(string_gen, max_length=10), max_length=10)]
ArrayGen(ArrayGen(string_gen, max_length=10), max_length=10),
ArrayGen(StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]]))]

# Some struct gens, but not all because of nesting.
# No empty struct gen because it leads to an error as below.
# '''
# E pyspark.sql.utils.AnalysisException:
# E Datasource does not support writing empty or nested empty schemas.
# E Please make sure the data schema has at least one or more column(s).
# '''
orc_struct_gens_sample = [orc_basic_struct_gen,
StructGen([['child0', byte_gen], ['child1', orc_basic_struct_gen]]),
StructGen([['child0', ArrayGen(short_gen)], ['child1', double_gen]])]

orc_gens_list = [orc_basic_gens,
orc_array_gens_sample,
orc_struct_gens_sample,
pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/131')),
pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/131'))]

@@ -110,15 +126,18 @@ def test_read_round_trip(spark_tmp_path, orc_gens, read_func, reader_confs, v1_e
@pytest.mark.parametrize('reader_confs', reader_opt_confs, ids=idfn)
def test_pred_push_round_trip(spark_tmp_path, orc_gen, read_func, v1_enabled_list, reader_confs):
data_path = spark_tmp_path + '/ORC_DATA'
gen_list = [('a', RepeatSeqGen(orc_gen, 100)), ('b', orc_gen)]
# Append two struct columns to verify nested predicate pushdown.
gen_list = [('a', RepeatSeqGen(orc_gen, 100)), ('b', orc_gen),
('s1', StructGen([['sa', orc_gen]])),
('s2', StructGen([['sa', StructGen([['ssa', orc_gen]])]]))]
s0 = gen_scalar(orc_gen, force_no_nulls=True)
with_cpu_session(
lambda spark : gen_df(spark, gen_list).orderBy('a').write.orc(data_path))
all_confs = reader_confs.copy()
all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list})
rf = read_func(data_path)
assert_gpu_and_cpu_are_equal_collect(
lambda spark: rf(spark).select(f.col('a') >= s0),
lambda spark: rf(spark).select(f.col('a') >= s0, f.col('s1.sa') >= s0, f.col('s2.sa.ssa') >= s0),
conf=all_confs)

orc_compress_options = ['none', 'uncompressed', 'snappy', 'zlib']
@@ -314,3 +333,21 @@ def test_missing_column_names_filter(spark_tmp_table_factory, reader_confs):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.sql("SELECT _col3,_col2 FROM {} WHERE _col2 = '155'".format(table_name)),
all_confs)


@pytest.mark.parametrize('data_gen,read_schema', _nested_pruning_schemas, ids=idfn)
@pytest.mark.parametrize('reader_confs', reader_opt_confs, ids=idfn)
@pytest.mark.parametrize('v1_enabled_list', ["", "orc"])
@pytest.mark.parametrize('nested_enabled', ["true", "false"])
def test_read_nested_pruning(spark_tmp_path, data_gen, read_schema, reader_confs, v1_enabled_list, nested_enabled):
data_path = spark_tmp_path + '/ORC_DATA'
with_cpu_session(
lambda spark : gen_df(spark, data_gen).write.orc(data_path))
all_confs = reader_confs.copy()
all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list,
'spark.sql.optimizer.nestedSchemaPruning.enabled': nested_enabled})
# This is a hack to get the type in a slightly less verbose way
rs = StructGen(read_schema, nullable=False).data_type
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.read.schema(rs).orc(data_path),
conf=all_confs)
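For reference, the nested-pruning read that test_read_nested_pruning drives from Python can be sketched on the Scala side as below. This is illustrative only: the path, the generated row, and the user-supplied read schema are assumptions, and the config key is the same one the test toggles.

import org.apache.spark.sql.SparkSession

object OrcNestedPruningSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("orc-nested-pruning-sketch")
      .master("local[*]")
      // The same switch the test parameterizes over.
      .config("spark.sql.optimizer.nestedSchemaPruning.enabled", "true")
      .getOrCreate()

    val path = "/tmp/orc_nested_pruning_example" // hypothetical location

    // Write one row with a struct column `st` holding three fields.
    spark.sql(
      "SELECT named_struct('c_1', 'a', 'c_2', CAST(1 AS BIGINT), " +
        "'c_3', CAST(2 AS SMALLINT)) AS st")
      .write.mode("overwrite").orc(path)

    // The user-supplied read schema keeps only st.c_1, so the scan can skip
    // materializing c_2 and c_3 from the file.
    spark.read.schema("st struct<c_1:string>").orc(path).show()

    spark.stop()
  }
}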
39 changes: 34 additions & 5 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
@@ -1119,7 +1119,7 @@ private case class GpuOrcFileFilterHandler(
* @param fileSchema input file's ORC schema
* @param readSchema ORC schema for what will be read
* @param isCaseAware true if field names are case-sensitive
* @return read schema mapped to the file's field names
* @return read schema if check passes.
*/
private def checkSchemaCompatibility(
fileSchema: TypeDescription,
@@ -1136,19 +1136,48 @@

val readerFieldNames = readSchema.getFieldNames.asScala
val readerChildren = readSchema.getChildren.asScala
val newReadSchema = TypeDescription.createStruct()
readerFieldNames.zip(readerChildren).foreach { case (readField, readType) =>
val (fileType, fileFieldName) = fileTypesMap.getOrElse(readField, (null, null))
if (readType != fileType) {
// When column pruning is enabled, the readType is not always equal to the fileType;
// it may be only part of the fileType, e.g.
// read type: struct<c_1:string>
// file type: struct<c_1:string,c_2:bigint,c_3:smallint>
if (!isSchemaCompatible(fileType, readType)) {
throw new QueryExecutionException("Incompatible schemas for ORC file" +
s" at ${partFile.filePath}\n" +
s" file schema: $fileSchema\n" +
s" read schema: $readSchema")
}
newReadSchema.addField(fileFieldName, fileType)
}
// To support nested column pruning, the original (pruned) read schema should be
// returned, instead of creating a new schema from the children of the file schema,
// which may contain more nested columns than the read schema and cause a mismatch
// between the pruned data and the pruned schema.
readSchema
}

newReadSchema
/**
* The read schema is compatible with the file schema only when
* 1) they are equal to each other, or
* 2) for struct types, the read schema is part of the file schema.
*
* @param fileSchema input file's ORC schema
* @param readSchema ORC schema for what will be read
* @return true if they are compatible, otherwise false
*/
private def isSchemaCompatible(
fileSchema: TypeDescription,
readSchema: TypeDescription): Boolean = {
fileSchema == readSchema ||
fileSchema != null && readSchema != null &&
fileSchema.getCategory == readSchema.getCategory && {
if (readSchema.getChildren != null) {
readSchema.getChildren.asScala.forall(rc =>
fileSchema.getChildren.asScala.exists(fc => isSchemaCompatible(fc, rc)))
} else {
false
}
}
}

/**
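A standalone sketch of the relationship the relaxed check above accepts, written against the ORC TypeDescription API rather than the plugin's private helpers. The schemas reuse the example from the inline comment, and the flat field-by-field comparison is a simplification of the recursive check: with nested column pruning the read schema can be a strict subset of the file schema, so plain equality is too strict.

import scala.collection.JavaConverters._

import org.apache.orc.TypeDescription

object OrcSchemaCompatSketch {
  def main(args: Array[String]): Unit = {
    val fileSchema =
      TypeDescription.fromString("struct<c_1:string,c_2:bigint,c_3:smallint>")
    val readSchema = TypeDescription.fromString("struct<c_1:string>")

    // Plain equality rejects a pruned read schema ...
    println(fileSchema == readSchema) // false

    // ... but every field the read schema asks for is present in the file schema
    // with a matching type, which is what the relaxed compatibility check allows.
    val fileFields = fileSchema.getFieldNames.asScala
      .zip(fileSchema.getChildren.asScala).toMap
    val readFields = readSchema.getFieldNames.asScala
      .zip(readSchema.getChildren.asScala)
    val compatible = readFields.forall { case (name, tpe) =>
      fileFields.get(name).contains(tpe)
    }
    println(compatible) // true
  }
}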
2 changes: 1 addition & 1 deletion sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -769,7 +769,7 @@ object GpuOverrides {
sparkSig = (TypeSig.atomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP +
TypeSig.UDT).nested())),
(OrcFormatType, FileFormatChecks(
cudfRead = (TypeSig.commonCudfTypes + TypeSig.ARRAY).nested(),
cudfRead = (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
cudfWrite = TypeSig.commonCudfTypes,
sparkSig = (TypeSig.atomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP +
TypeSig.UDT).nested())))
2 changes: 1 addition & 1 deletion tools/src/main/resources/supportedDataSource.csv
@@ -1,4 +1,4 @@
Format,Direction,BOOLEAN,BYTE,SHORT,INT,LONG,FLOAT,DOUBLE,DATE,TIMESTAMP,STRING,DECIMAL,NULL,BINARY,CALENDAR,ARRAY,MAP,STRUCT,UDT
CSV,read,CO,CO,CO,CO,CO,CO,CO,CO,CO,S,CO,NA,NS,NA,NA,NA,NA,NA
ORC,read,S,S,S,S,S,S,S,S,PS,S,CO,NA,NS,NA,PS,NS,NS,NS
ORC,read,S,S,S,S,S,S,S,S,PS,S,CO,NA,NS,NA,PS,NS,PS,NS
Parquet,read,S,S,S,S,S,S,S,S,PS,S,CO,NA,NS,NA,PS,PS,PS,NS
2 changes: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
App Name,App ID,Score,Potential Problems,SQL DF Duration,SQL Dataframe Task Duration,App Duration,Executor CPU Time Percent,App Duration Estimated,SQL Duration with Potential Problems,SQL Ids with Failures,Read Score Percent,Read File Format Score,Unsupported Read File Formats and Types
Spark shell,local-1626104300434,1322.1,DECIMAL,2429,1469,131104,88.35,false,160,"",20,50.0,Parquet[decimal];ORC[map:struct:decimal]
Spark shell,local-1626104300434,1322.1,DECIMAL,2429,1469,131104,88.35,false,160,"",20,50.0,Parquet[decimal];ORC[map:decimal]
