Pass metadata extractors to FileScanRDD [databricks] (#10616)
* Pass metadata extractors to FileScanRDD

* Signing off

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review comments

* updated copyrights manually

---------

Signed-off-by: Raza Jafri <rjafri@nvidia.com>
razajafri authored Mar 22, 2024
1 parent 579cc13 commit 09a0081
Showing 7 changed files with 31 additions and 13 deletions.
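
The change threads the scan's FileFormat from GpuFileSourceScanExec through the version-specific shims so that, on Spark 3.4+, FileScanRDD receives the format's constant-metadata extractors. For context, here is a minimal sketch of the shape of such an extractor map, assuming a Spark 3.4+ classpath; the column name below is hypothetical and used only for illustration:

import org.apache.spark.sql.execution.datasources.PartitionedFile

// A constant-metadata extractor computes a metadata column's value once per
// file rather than once per row. Spark 3.4+ exposes these on FileFormat as
// fileConstantMetadataExtractors: Map[String, PartitionedFile => Any].
// "file_start_offset" is a hypothetical column name, not part of this commit.
val extractors: Map[String, PartitionedFile => Any] = Map(
  "file_start_offset" -> ((file: PartitionedFile) => file.start)
)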
@@ -30,7 +30,7 @@ import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.command.{DataWritingCommand, RunnableCommand}
-import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile, PartitioningAwareFileIndex}
+import org.apache.spark.sql.execution.datasources.{FileFormat, FilePartition, PartitionedFile, PartitioningAwareFileIndex}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
 import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.internal.SQLConf
@@ -78,7 +78,8 @@ trait SparkShims {
       readFunction: (PartitionedFile) => Iterator[InternalRow],
       filePartitions: Seq[FilePartition],
       readDataSchema: StructType,
-      metadataColumns: Seq[AttributeReference] = Seq.empty): RDD[InternalRow]
+      metadataColumns: Seq[AttributeReference] = Seq.empty,
+      fileFormat: Option[FileFormat] = None): RDD[InternalRow]

   def shouldFailDivOverflow: Boolean
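Because fileFormat defaults to None in the trait, existing callers of getFileScanRDD compile unchanged; only call sites that want constant-metadata extraction need to pass a FileFormat.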
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -603,7 +603,7 @@ case class GpuFileSourceScanExec(
     if (isPerFileReadEnabled) {
       logInfo("Using the original per file reader")
       SparkShimImpl.getFileScanRDD(relation.sparkSession, readFile.get, locatedPartitions,
-        requiredSchema)
+        requiredSchema, fileFormat = Some(relation.fileFormat))
     } else {
       logDebug(s"Using Datasource RDD, files are: " +
         s"${prunedPartitions.flatMap(_.files).mkString(",")}")
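The per-file read path now forwards the relation's FileFormat so the Spark 3.4+ shim can hand its extractors to FileScanRDD; on earlier Spark versions the shims simply accept and ignore the argument, as the following hunks show.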
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, RunnableCommand}
-import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.{FileFormat, FilePartition, FileScanRDD, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.types.StructType

@@ -50,7 +50,8 @@ trait Spark31Xuntil33XShims extends SparkShims {
       readFunction: PartitionedFile => Iterator[InternalRow],
       filePartitions: Seq[FilePartition],
       readDataSchema: StructType,
-      metadataColumns: Seq[AttributeReference]): RDD[InternalRow] = {
+      metadataColumns: Seq[AttributeReference],
+      fileFormat: Option[FileFormat]): RDD[InternalRow] = {
     new FileScanRDD(sparkSession, readFunction, filePartitions)
   }

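On the Spark versions this shim covers, FileScanRDD's constructor accepts neither metadata columns nor extractors, so both metadataColumns and the new fileFormat argument are dropped here.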
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -37,7 +37,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.{FileFormat, FilePartition, FileScanRDD, PartitionedFile}
 import org.apache.spark.sql.rapids.shims.{GpuDivideYMInterval, GpuMultiplyYMInterval}
 import org.apache.spark.sql.types.StructType

@@ -50,7 +50,8 @@ trait Spark330PlusShims extends Spark321PlusShims with Spark320PlusNonDBShims {
       readFunction: PartitionedFile => Iterator[InternalRow],
       filePartitions: Seq[FilePartition],
       readDataSchema: StructType,
-      metadataColumns: Seq[AttributeReference]): RDD[InternalRow] = {
+      metadataColumns: Seq[AttributeReference],
+      fileFormat: Option[FileFormat]): RDD[InternalRow] = {
     new FileScanRDD(sparkSession, readFunction, filePartitions, readDataSchema, metadataColumns)
   }

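Spark 3.3.x's FileScanRDD understands metadataColumns but predates the metadataExtractors parameter, so this shim passes the metadata columns through and ignores fileFormat.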
@@ -56,7 +56,8 @@ trait Spark321PlusDBShims extends SparkShims
       readFunction: PartitionedFile => Iterator[InternalRow],
       filePartitions: Seq[FilePartition],
       readDataSchema: StructType,
-      metadataColumns: Seq[AttributeReference]): RDD[InternalRow] = {
+      metadataColumns: Seq[AttributeReference],
+      fileFormat: Option[FileFormat]): RDD[InternalRow] = {
     new GpuFileScanRDD(sparkSession, readFunction, filePartitions)
   }

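The Databricks shims build a GpuFileScanRDD instead; they take the new parameter only to satisfy the common SparkShims signature and do not use it.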
@@ -22,15 +22,29 @@ package com.nvidia.spark.rapids.shims

 import com.nvidia.spark.rapids._

-import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDAF, ToPrettyString}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, PythonUDAF, ToPrettyString}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.adaptive.TableCacheQueryStageExec
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.datasources.{FileFormat, FilePartition, FileScanRDD, PartitionedFile}
 import org.apache.spark.sql.execution.window.WindowGroupLimitExec
 import org.apache.spark.sql.rapids.execution.python.GpuPythonUDAF
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{StringType, StructType}

 object SparkShimImpl extends Spark340PlusNonDBShims {
+  override def getFileScanRDD(
+      sparkSession: SparkSession,
+      readFunction: PartitionedFile => Iterator[InternalRow],
+      filePartitions: Seq[FilePartition],
+      readDataSchema: StructType,
+      metadataColumns: Seq[AttributeReference] = Seq.empty,
+      fileFormat: Option[FileFormat]): RDD[InternalRow] = {
+    new FileScanRDD(sparkSession, readFunction, filePartitions, readDataSchema, metadataColumns,
+      metadataExtractors = fileFormat.map(_.fileConstantMetadataExtractors).getOrElse(Map.empty))
+  }
+
   override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
     val shimExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
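A minimal sketch of the resolution logic the override above applies, isolated for clarity (assumes a Spark 3.4+ classpath; extractorsFor is a hypothetical helper name, not part of the commit):

import org.apache.spark.sql.execution.datasources.{FileFormat, PartitionedFile}

// Resolve the extractor map FileScanRDD should use: the format's own
// constant-metadata extractors when a FileFormat was supplied, else empty.
def extractorsFor(fileFormat: Option[FileFormat]): Map[String, PartitionedFile => Any] =
  fileFormat.map(_.fileConstantMetadataExtractors).getOrElse(Map.empty)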
