Skip to content

Commit

Permalink
Fix orc read data mess up for the schema can't be pruned (#3062)
Browse files Browse the repository at this point in the history
This PR first fixes the ORC data-read mix-up for schemas that
cannot be pruned, and then adds disordered read-schema unit tests
for both Parquet and ORC.

Signed-off-by: Bobby Wang <wbo4958@gmail.com>
  • Loading branch information
wbo4958 authored Jul 29, 2021
1 parent 0837e63 commit c4b4ae6
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ trait OrcCommonFunctions extends OrcCodecWritingHelper {
protected def buildReaderSchema(ctx: OrcPartitionReaderContext): TypeDescription = {
if (ctx.requestedMapping.isDefined) {
// filter top-level schema based on requested mapping
val orcSchema = ctx.fileSchema
val orcSchema = ctx.updatedReadSchema
val orcSchemaNames = orcSchema.getFieldNames
val orcSchemaChildren = orcSchema.getChildren
val readerSchema = TypeDescription.createStruct()
Expand Down
Binary file not shown.
Binary file added tests/src/test/resources/schema-can-prune.orc
Binary file not shown.
Binary file added tests/src/test/resources/schema-cant-prune.orc
Binary file not shown.
65 changes: 64 additions & 1 deletion tests/src/test/scala/com/nvidia/spark/rapids/OrcScanSuite.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,7 @@ package com.nvidia.spark.rapids

import org.apache.spark.SparkConf
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}

class OrcScanSuite extends SparkQueryCompareTestSuite {

Expand All @@ -33,4 +34,66 @@ class OrcScanSuite extends SparkQueryCompareTestSuite {

testSparkResultsAreEqual("Test ORC count chunked by bytes", fileSplitsOrc,
new SparkConf().set(RapidsConf.MAX_READER_BATCH_SIZE_BYTES.key, "100"))(frameCount)

// Read schema-can-prune.orc with read schemas whose column order differs from the
// file schema (and with subsets of the columns), verifying CPU and GPU agree.
// The identity transform (frame => frame) just materializes the read result.
testSparkResultsAreEqual("schema-can-prune dis-order read schema",
frameFromOrcWithSchema("schema-can-prune.orc", StructType(Seq(
StructField("c2_string", StringType),
StructField("c3_long", LongType),
StructField("c1_int", IntegerType))))) { frame => frame }

testSparkResultsAreEqual("schema-can-prune dis-order read schema 1",
frameFromOrcWithSchema("schema-can-prune.orc", StructType(Seq(
StructField("c2_string", StringType),
StructField("c1_int", IntegerType),
StructField("c3_long", LongType))))) { frame => frame }

testSparkResultsAreEqual("schema-can-prune dis-order read schema 2",
frameFromOrcWithSchema("schema-can-prune.orc", StructType(Seq(
StructField("c3_long", LongType),
StructField("c2_string", StringType),
StructField("c1_int", IntegerType))))) { frame => frame }

// The remaining cases drop one column as well as reordering, so they exercise
// pruning and reordering together.
testSparkResultsAreEqual("schema-can-prune dis-order read schema 3",
frameFromOrcWithSchema("schema-can-prune.orc", StructType(Seq(
StructField("c3_long", LongType),
StructField("c2_string", StringType))))) { frame => frame }

testSparkResultsAreEqual("schema-can-prune dis-order read schema 4",
frameFromOrcWithSchema("schema-can-prune.orc", StructType(Seq(
StructField("c2_string", StringType),
StructField("c1_int", IntegerType))))) { frame => frame }

testSparkResultsAreEqual("schema-can-prune dis-order read schema 5",
frameFromOrcWithSchema("schema-can-prune.orc", StructType(Seq(
StructField("c3_long", LongType),
StructField("c1_int", IntegerType))))) { frame => frame }

/**
 * We can't compare the results from CPU and GPU, since the CPU will get an incorrect
 * result; see https://github.com/NVIDIA/spark-rapids/issues/3060.
 * Instead, run GPU-only and assert the expected values directly.
 */
test("schema can't be pruned") {
withGpuSparkSession( spark => {
// Request the columns in a different order than the file stores them;
// the values asserted below must follow the requested order, not file order.
val df = frameFromOrcWithSchema("schema-cant-prune.orc",
StructType(Seq(
StructField("_col2", StringType),
StructField("_col3", LongType),
StructField("_col1", IntegerType))))(spark)
val ret = df.collect()
assert(ret(0).getString(0) === "hello")
assert(ret(0).getLong(1) === 2021)
assert(ret(0).getInt(2) === 1)

// Same file, a second permutation of the read schema.
val df1 = frameFromOrcWithSchema("schema-cant-prune.orc",
StructType(Seq(
StructField("_col3", LongType),
StructField("_col1", IntegerType),
StructField("_col2", StringType))))(spark)
val ret1 = df1.collect()
assert(ret1(0).getLong(0) === 2021)
assert(ret1(0).getInt(1) === 1)
assert(ret1(0).getString(2) === "hello")
})
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,6 +22,7 @@ import java.nio.file.Files
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}

class ParquetScanSuite extends SparkQueryCompareTestSuite {
private val fileSplitsParquet = frameFromParquet("file-splits.parquet")
Expand Down Expand Up @@ -56,4 +57,37 @@ class ParquetScanSuite extends SparkQueryCompareTestSuite {
frameFromParquet("decimal-test-legacy.parquet")) {
frame => frame.select(col("*"))
}

// Read disorder-read-schema.parquet with read schemas whose column order differs
// from the file schema (and with subsets of the columns), verifying CPU and GPU
// agree. The identity transform (frame => frame) just materializes the read result.
testSparkResultsAreEqual("parquet dis-order read schema",
frameFromParquetWithSchema("disorder-read-schema.parquet", StructType(Seq(
StructField("c2_string", StringType),
StructField("c3_long", LongType),
StructField("c1_int", IntegerType))))) { frame => frame }

testSparkResultsAreEqual("parquet dis-order read schema 1",
frameFromParquetWithSchema("disorder-read-schema.parquet", StructType(Seq(
StructField("c2_string", StringType),
StructField("c1_int", IntegerType),
StructField("c3_long", LongType))))) { frame => frame }

testSparkResultsAreEqual("parquet dis-order read schema 2",
frameFromParquetWithSchema("disorder-read-schema.parquet", StructType(Seq(
StructField("c3_long", LongType),
StructField("c2_string", StringType),
StructField("c1_int", IntegerType))))) { frame => frame }

// The remaining cases drop one column as well as reordering, so they exercise
// pruning and reordering together.
testSparkResultsAreEqual("parquet dis-order read schema 3",
frameFromParquetWithSchema("disorder-read-schema.parquet", StructType(Seq(
StructField("c3_long", LongType),
StructField("c2_string", StringType))))) { frame => frame }

testSparkResultsAreEqual("parquet dis-order read schema 4",
frameFromParquetWithSchema("disorder-read-schema.parquet", StructType(Seq(
StructField("c2_string", StringType),
StructField("c1_int", IntegerType))))) { frame => frame }

testSparkResultsAreEqual("parquet dis-order read schema 5",
frameFromParquetWithSchema("disorder-read-schema.parquet", StructType(Seq(
StructField("c3_long", LongType),
StructField("c1_int", IntegerType))))) { frame => frame }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1788,11 +1788,22 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm {
s: SparkSession => s.read.parquet(path)
}

/**
 * Builds a function that reads the named Parquet test resource with an explicit
 * read schema, so tests can request columns in an order (or subset) that differs
 * from the file schema.
 *
 * @param filename test-resource file name, resolved via TestResourceFinder
 * @param schema   the read schema to apply with spark.read.schema(...)
 * @return a function producing the DataFrame from a given SparkSession
 */
def frameFromParquetWithSchema(filename: String, schema: StructType):
SparkSession => DataFrame = {
val path = TestResourceFinder.getResourcePath(filename)
s: SparkSession => s.read.schema(schema).parquet(path)
}

/**
 * Builds a function that reads the named ORC test resource with Spark's
 * inferred schema (no explicit read schema).
 *
 * @param filename test-resource file name, resolved via TestResourceFinder
 * @return a function producing the DataFrame from a given SparkSession
 */
def frameFromOrc(filename: String): SparkSession => DataFrame = {
val path = TestResourceFinder.getResourcePath(filename)
s: SparkSession => s.read.orc(path)
}

/**
 * Builds a function that reads the named ORC test resource with an explicit
 * read schema, so tests can request columns in an order (or subset) that differs
 * from the file schema.
 *
 * @param filename test-resource file name, resolved via TestResourceFinder
 * @param schema   the read schema to apply with spark.read.schema(...)
 * @return a function producing the DataFrame from a given SparkSession
 */
def frameFromOrcWithSchema(filename: String, schema: StructType): SparkSession => DataFrame = {
val path = TestResourceFinder.getResourcePath(filename)
s: SparkSession => s.read.schema(schema).orc(path)
}

def frameFromOrcNonNullableColumns(filename: String): SparkSession => DataFrame = {
val path = TestResourceFinder.getResourcePath(filename)
s: SparkSession => {
Expand Down

0 comments on commit c4b4ae6

Please sign in to comment.