[SPARK-32646][SQL][3.0][TEST-HADOOP2.7][TEST-HIVE1.2] ORC predicate pushdown should work with case-insensitive analysis

### What changes were proposed in this pull request?

This PR proposes to fix ORC predicate pushdown under case-insensitive analysis. With case-insensitive analysis enabled, field names in pushed-down predicates no longer need to match the physical field names in ORC files in exact letter case.
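
At its core, the fix builds the pushdown lookup from the ORC file's physical schema (via the new `OrcFiltersBase.getDataTypeMap`), backed by Spark's `CaseInsensitiveMap` when case-sensitive analysis is disabled. A minimal sketch of that idea (not the patch code; the schema and field names below are illustrative):

```scala
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.types._

// The physical ORC schema uses an upper-case field name.
val physicalSchema = StructType(Seq(StructField("A", IntegerType)))

// Key the lookup by the physical field name; lookups then ignore letter case.
val fieldsByName = CaseInsensitiveMap(physicalSchema.map(f => f.name -> f.dataType).toMap)

// A predicate written against "a" can now resolve the physical field "A".
assert(fieldsByName.contains("a"))
assert(fieldsByName("a") == IntegerType)
```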

### Why are the changes needed?

Currently, ORC predicate pushdown doesn't work with case-insensitive analysis. For example, a predicate "a < 0" cannot be pushed down to an ORC file whose physical field name is "A", even with case-insensitive analysis enabled.

Parquet predicate pushdown already handles this case, so ORC predicate pushdown should work with case-insensitive analysis too.
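
A rough way to reproduce the scenario above (not taken from the patch; the local session and the scratch path are illustrative):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

spark.conf.set("spark.sql.caseSensitive", "false")

// The physical ORC field name is upper-case "A".
Seq(-1, 0, 1).toDF("A").write.mode("overwrite").orc("/tmp/orc_case_test")

// The filter uses lower-case "a". Before this fix the predicate was not pushed
// down to the ORC reader; with the fix, the pushdown applies.
val df = spark.read.orc("/tmp/orc_case_test").filter("a < 0")
df.explain() // inspect the pushed filters reported by the scan node
df.show()    // returns the single row with value -1 either way
```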

### Does this PR introduce _any_ user-facing change?

Yes. After this PR, ORC predicate pushdown works under case-insensitive analysis.

### How was this patch tested?

Unit tests.
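
For reference, a sketch of the kind of round trip the new unit tests exercise (this is not the test code from the patch; the column name, path, and session setup are illustrative):

```scala
import java.nio.file.Files
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
spark.conf.set("spark.sql.caseSensitive", "false")

val path = Files.createTempDirectory("orc-pushdown").toString
spark.range(10).toDF("ID").write.mode("overwrite").orc(path)

// Query with a lower-case column name; the filter should be pushed down to ORC
// and the result should stay correct.
val df = spark.read.orc(path).where("id < 5")
assert(df.count() == 5)
df.explain() // the ORC scan should now report a pushed filter on the physical column
```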

Closes #29513 from viirya/fix-orc-pushdown-3.0.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
viirya authored and cloud-fan committed Aug 25, 2020
1 parent 82aef3e commit 6c88d7c
Showing 10 changed files with 253 additions and 85 deletions.
@@ -153,11 +153,6 @@ class OrcFileFormat
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
if (sparkSession.sessionState.conf.orcFilterPushDown) {
OrcFilters.createFilter(dataSchema, filters).foreach { f =>
OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames)
}
}

val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields)
val sqlConf = sparkSession.sessionState.conf
@@ -169,6 +164,8 @@ class OrcFileFormat
val broadcastedConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
val orcFilterPushDown = sparkSession.sessionState.conf.orcFilterPushDown
val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles

(file: PartitionedFile) => {
val conf = broadcastedConf.value.value
@@ -186,6 +183,15 @@ class OrcFileFormat
if (resultedColPruneInfo.isEmpty) {
Iterator.empty
} else {
// ORC predicate pushdown
if (orcFilterPushDown) {
OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema =>
OrcFilters.createFilter(fileSchema, filters).foreach { f =>
OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
}
}
}

val (requestedColIds, canPruneCols) = resultedColPruneInfo.get
val resultSchemaString = OrcUtils.orcResultSchemaString(canPruneCols,
dataSchema, resultSchema, partitionSchema, conf)
@@ -17,14 +17,45 @@

package org.apache.spark.sql.execution.datasources.orc

import java.util.Locale

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
import org.apache.spark.sql.sources.{And, Filter}
import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType}
import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType, StructType}

/**
* Methods that can be shared when upgrading the built-in Hive.
*/
trait OrcFiltersBase {

case class OrcPrimitiveField(fieldName: String, fieldType: DataType)

protected[sql] def getDataTypeMap(
schema: StructType,
caseSensitive: Boolean): Map[String, OrcPrimitiveField] = {
val fields = schema.flatMap { f =>
if (isSearchableType(f.dataType)) {
Some(quoteIfNeeded(f.name) -> OrcPrimitiveField(quoteIfNeeded(f.name), f.dataType))
} else {
None
}
}

if (caseSensitive) {
fields.toMap
} else {
// Don't consider ambiguity here: if more than one field matches in case-insensitive
// mode, just skip pushdown for those fields; they will trigger an exception when read.
// See: SPARK-25175.
val dedupPrimitiveFields = fields
.groupBy(_._1.toLowerCase(Locale.ROOT))
.filter(_._2.size == 1)
.mapValues(_.head._2)
CaseInsensitiveMap(dedupPrimitiveFields)
}
}

private[sql] def buildTree(filters: Seq[Filter]): Option[Filter] = {
filters match {
case Seq() => None
@@ -40,7 +71,7 @@ trait OrcFiltersBase
* Return true if this is a searchable type in ORC.
* Both CharType and VarcharType are cleaned at AstBuilder.
*/
protected[sql] def isSearchableType(dataType: DataType) = dataType match {
private def isSearchableType(dataType: DataType) = dataType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
@@ -92,6 +92,20 @@ object OrcUtils extends Logging {
}
}

def readCatalystSchema(
file: Path,
conf: Configuration,
ignoreCorruptFiles: Boolean): Option[StructType] = {
readSchema(file, conf, ignoreCorruptFiles) match {
case Some(schema) =>
Some(CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType])

case None =>
// Field names are empty or a `FileFormatException` was thrown but ignoreCorruptFiles is true.
None
}
}

/**
* Reads ORC file schemas in multi-threaded manner, using native version of ORC.
* This is visible for testing.
@@ -31,9 +31,10 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcUtils}
import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils}
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.{AtomicType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -52,24 +53,39 @@ case class OrcPartitionReaderFactory(
broadcastedConf: Broadcast[SerializableConfiguration],
dataSchema: StructType,
readDataSchema: StructType,
partitionSchema: StructType) extends FilePartitionReaderFactory {
partitionSchema: StructType,
filters: Array[Filter]) extends FilePartitionReaderFactory {
private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields)
private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
private val capacity = sqlConf.orcVectorizedReaderBatchSize
private val orcFilterPushDown = sqlConf.orcFilterPushDown
private val ignoreCorruptFiles = sqlConf.ignoreCorruptFiles

override def supportColumnarReads(partition: InputPartition): Boolean = {
sqlConf.orcVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
resultSchema.length <= sqlConf.wholeStageMaxNumFields &&
resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
}

private def pushDownPredicates(filePath: Path, conf: Configuration): Unit = {
if (orcFilterPushDown) {
OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema =>
OrcFilters.createFilter(fileSchema, filters).foreach { f =>
OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
}
}
}
}

override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
val conf = broadcastedConf.value.value

OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)

val filePath = new Path(new URI(file.filePath))

pushDownPredicates(filePath, conf)

val fs = filePath.getFileSystem(conf)
val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
val resultedColPruneInfo =
@@ -116,6 +132,8 @@ case class OrcPartitionReaderFactory(

val filePath = new Path(new URI(file.filePath))

pushDownPredicates(filePath, conf)

val fs = filePath.getFileSystem(conf)
val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
val resultedColPruneInfo =
@@ -48,7 +48,7 @@ case class OrcScan(
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, readDataSchema, readPartitionSchema)
dataSchema, readDataSchema, readPartitionSchema, pushedFilters)
}

override def equals(obj: Any): Boolean = obj match {
@@ -22,11 +22,11 @@ import scala.collection.JavaConverters._
import org.apache.orc.mapreduce.OrcInputFormat

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.orc.OrcFilters
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -55,12 +55,7 @@ case class OrcScanBuilder(

override def pushFilters(filters: Array[Filter]): Array[Filter] = {
if (sparkSession.sessionState.conf.orcFilterPushDown) {
OrcFilters.createFilter(schema, filters).foreach { f =>
// The pushed filters will be set in `hadoopConf`. After that, we can simply use the
// changed `hadoopConf` in executors.
OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames)
}
val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
val dataTypeMap = OrcFilters.getDataTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
// TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
val newFilters = filters.filter(!_.containsNestedColumn)
_pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, newFilters).toArray
@@ -27,7 +27,7 @@ import org.apache.orc.storage.serde2.io.HiveDecimalWritable

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types._

@@ -68,7 +68,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* Create ORC filter as a SearchArgument instance.
*/
def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
val dataTypeMap = getDataTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
// Combines all convertible filters using `And` to produce a single conjunction
// TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
val newFilters = filters.filter(!_.containsNestedColumn)
@@ -83,7 +83,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {

def convertibleFilters(
schema: StructType,
dataTypeMap: Map[String, DataType],
dataTypeMap: Map[String, OrcPrimitiveField],
filters: Seq[Filter]): Seq[Filter] = {
import org.apache.spark.sql.sources._

@@ -141,7 +141,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
/**
* Get PredicateLeafType which is corresponding to the given DataType.
*/
private def getPredicateLeafType(dataType: DataType) = dataType match {
private[sql] def getPredicateLeafType(dataType: DataType) = dataType match {
case BooleanType => PredicateLeaf.Type.BOOLEAN
case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
@@ -181,7 +181,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* @return the builder so far.
*/
private def buildSearchArgument(
dataTypeMap: Map[String, DataType],
dataTypeMap: Map[String, OrcPrimitiveField],
expression: Filter,
builder: Builder): Builder = {
import org.apache.spark.sql.sources._
@@ -217,11 +217,11 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* @return the builder so far.
*/
private def buildLeafSearchArgument(
dataTypeMap: Map[String, DataType],
dataTypeMap: Map[String, OrcPrimitiveField],
expression: Filter,
builder: Builder): Option[Builder] = {
def getType(attribute: String): PredicateLeaf.Type =
getPredicateLeafType(dataTypeMap(attribute))
getPredicateLeafType(dataTypeMap(attribute).fieldType)

import org.apache.spark.sql.sources._

@@ -231,39 +231,47 @@ private[sql] object OrcFilters extends OrcFiltersBase {
// Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters
// in order to distinguish predicate pushdown for nested columns.
expression match {
case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().equals(name, getType(name), castedValue).end())
case EqualTo(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.equals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end())
case EqualNullSafe(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.nullSafeEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case LessThan(name, value) if isSearchableType(dataTypeMap(name)) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().lessThan(name, getType(name), castedValue).end())
case LessThan(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end())
case LessThanOrEqual(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startAnd()
.lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end())
case GreaterThan(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startNot()
.lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startNot().lessThan(name, getType(name), castedValue).end())
case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
Some(builder.startNot()
.lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())

case IsNull(name) if isSearchableType(dataTypeMap(name)) =>
Some(builder.startAnd().isNull(name, getType(name)).end())
case IsNull(name) if dataTypeMap.contains(name) =>
Some(builder.startAnd()
.isNull(dataTypeMap(name).fieldName, getType(name)).end())

case IsNotNull(name) if isSearchableType(dataTypeMap(name)) =>
Some(builder.startNot().isNull(name, getType(name)).end())
case IsNotNull(name) if dataTypeMap.contains(name) =>
Some(builder.startNot()
.isNull(dataTypeMap(name).fieldName, getType(name)).end())

case In(name, values) if isSearchableType(dataTypeMap(name)) =>
val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name)))
Some(builder.startAnd().in(name, getType(name),
case In(name, values) if dataTypeMap.contains(name) =>
val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name).fieldType))
Some(builder.startAnd().in(dataTypeMap(name).fieldName, getType(name),
castedValues.map(_.asInstanceOf[AnyRef]): _*).end())

case _ => None
(Diff truncated; the remaining changed files are not shown.)