Unshim GpuOrcScan and GpuParquetScan [databricks] (#5012)
* Unshim GpuParquetScan

Signed-off-by: Jason Lowe <jlowe@nvidia.com>

* Unshim GpuOrcScan

Signed-off-by: Jason Lowe <jlowe@nvidia.com>
jlowe authored Mar 23, 2022
1 parent 6465df0 commit 5641e0a
Showing 13 changed files with 88 additions and 175 deletions.

Two files were deleted (file names not preserved in this capture).
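
For context: the deleted files are presumably the per-Spark-version scan classes made redundant by the unshim. Before this commit, common code exposed abstract GpuOrcScanBase/GpuParquetScanBase classes and each shim layer declared a nearly identical concrete case class on top of them. The self-contained Scala sketch below models that before/after shape; everything except the GpuOrcScanBase/GpuOrcScan names is a simplified stand-in invented for illustration, not the plugin's real API.

    // Hedged model of the "unshim" refactor; all types are simplified
    // stand-ins so this sketch compiles on its own.
    object UnshimSketch {
      trait FileScan { def isSplitable(path: String): Boolean }

      // Before: an abstract base in common code...
      abstract class GpuOrcScanBase {
        def isSplitableBase(path: String): Boolean = true
      }
      // ...plus a near-duplicate concrete class in every shim directory,
      // which only forwarded to the base implementation.
      case class ShimGpuOrcScan() extends GpuOrcScanBase with FileScan {
        override def isSplitable(path: String): Boolean = isSplitableBase(path)
      }

      // After: one concrete case class in common code; the forwarding
      // layer (the deleted files above) disappears.
      case class GpuOrcScan() extends FileScan {
        override def isSplitable(path: String): Boolean = true
      }
    }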

@@ -16,7 +16,7 @@

 package com.nvidia.spark.rapids.shims

-import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuOrcScanBase, RapidsConf, RapidsMeta, ScanMeta}
+import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuOrcScan, RapidsConf, RapidsMeta, ScanMeta}

 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
@@ -29,7 +29,7 @@ class RapidsOrcScanMeta(
   extends ScanMeta[OrcScan](oScan, conf, parent, rule) {

   override def tagSelfForGpu(): Unit = {
-    GpuOrcScanBase.tagSupport(this)
+    GpuOrcScan.tagSupport(this)
   }

   override def convertToGpu(): Scan =
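
With the scan class unshimmed, the per-version RapidsOrcScanMeta above only needs its delegation target renamed, from the abstract base's companion object to the concrete class's companion object. As a rough model of what a companion-object tagSupport does (stand-in types; the real ScanMeta/RapidsMeta machinery carries much more state):

    // Minimal stand-in for the tag-for-GPU flow; not the plugin's real classes.
    object TagSupportSketch {
      final class ScanMeta[A](val wrapped: A) {
        private var blockedReason: Option[String] = None
        def willNotWorkOnGpu(why: String): Unit = blockedReason = Some(why)
        def canThisBeReplaced: Boolean = blockedReason.isEmpty
      }

      case class CpuOrcScan(readColumns: Seq[String])

      object GpuOrcScan {
        // Companion-object entry point, mirroring GpuOrcScan.tagSupport(this).
        def tagSupport(meta: ScanMeta[CpuOrcScan]): Unit =
          if (meta.wrapped.readColumns.isEmpty)
            meta.willNotWorkOnGpu("hypothetical rule: no columns selected")
      }
    }

The shim's tagSelfForGpu then reduces to a single call into common code, so version-specific files no longer duplicate the tagging logic.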
@@ -16,7 +16,7 @@

 package com.nvidia.spark.rapids.shims

-import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuParquetScanBase, RapidsConf, RapidsMeta, ScanMeta}
+import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuParquetScan, RapidsConf, RapidsMeta, ScanMeta}

 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -29,7 +29,7 @@ class RapidsParquetScanMeta(
   extends ScanMeta[ParquetScan](pScan, conf, parent, rule) {

   override def tagSelfForGpu(): Unit = {
-    GpuParquetScanBase.tagSupport(this)
+    GpuParquetScan.tagSupport(this)
   }

   override def convertToGpu(): Scan = {
@@ -16,7 +16,7 @@

 package com.nvidia.spark.rapids.shims

-import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuOrcScanBase, RapidsConf, RapidsMeta, ScanMeta}
+import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuOrcScan, RapidsConf, RapidsMeta, ScanMeta}

 import org.apache.spark.sql.connector.read.{Scan, SupportsRuntimeFiltering}
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
@@ -29,7 +29,7 @@ class RapidsOrcScanMeta(
   extends ScanMeta[OrcScan](oScan, conf, parent, rule) {

   override def tagSelfForGpu(): Unit = {
-    GpuOrcScanBase.tagSupport(this)
+    GpuOrcScan.tagSupport(this)
     // we are being overly cautious and that Orc does not support this yet
     if (oScan.isInstanceOf[SupportsRuntimeFiltering]) {
       willNotWorkOnGpu("Orc does not support Runtime filtering (DPP)" +
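
The extra guard in this variant of the shim is a plain runtime capability test: on Spark versions where OrcScan can implement SupportsRuntimeFiltering (the DataSource V2 hook behind dynamic partition pruning), the scan is tagged as unable to run on the GPU. A small hedged sketch of that isInstanceOf pattern, again with stand-in types rather than Spark's real ones:

    // Stand-ins for Spark's DSv2 types, showing the shape of the guard.
    object RuntimeFilteringGuardSketch {
      trait Scan
      trait SupportsRuntimeFiltering extends Scan
      case class PlainOrcScan() extends Scan
      case class FilteringOrcScan() extends Scan with SupportsRuntimeFiltering

      def tag(scan: Scan)(willNotWorkOnGpu: String => Unit): Unit =
        // Same shape as the diff: refuse the GPU path when the scan
        // advertises runtime filtering, which the GPU reader cannot honor yet.
        if (scan.isInstanceOf[SupportsRuntimeFiltering])
          willNotWorkOnGpu("Orc does not support Runtime filtering (DPP) on GPU")
    }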
@@ -16,7 +16,7 @@

 package com.nvidia.spark.rapids.shims

-import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuParquetScanBase, RapidsConf, RapidsMeta, ScanMeta}
+import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuParquetScan, RapidsConf, RapidsMeta, ScanMeta}

 import org.apache.spark.sql.connector.read.{Scan, SupportsRuntimeFiltering}
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -29,7 +29,7 @@ class RapidsParquetScanMeta(
   extends ScanMeta[ParquetScan](pScan, conf, parent, rule) {

   override def tagSelfForGpu(): Unit = {
-    GpuParquetScanBase.tagSupport(this)
+    GpuParquetScan.tagSupport(this)
     // we are being overly cautious and that Parquet does not support this yet
     if (pScan.isInstanceOf[SupportsRuntimeFiltering]) {
       willNotWorkOnGpu("Parquet does not support Runtime filtering (DPP)" +
@@ -16,7 +16,7 @@

 package com.nvidia.spark.rapids.shims

-import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuOrcScanBase, RapidsConf, RapidsMeta, ScanMeta}
+import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuOrcScan, RapidsConf, RapidsMeta, ScanMeta}

 import org.apache.spark.sql.connector.read.{Scan, SupportsRuntimeFiltering}
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
@@ -29,7 +29,7 @@ class RapidsOrcScanMeta(
   extends ScanMeta[OrcScan](oScan, conf, parent, rule) {

   override def tagSelfForGpu(): Unit = {
-    GpuOrcScanBase.tagSupport(this)
+    GpuOrcScan.tagSupport(this)
     // we are being overly cautious and that Orc does not support this yet
     if (oScan.isInstanceOf[SupportsRuntimeFiltering]) {
       willNotWorkOnGpu("Orc does not support Runtime filtering (DPP)" +
@@ -16,7 +16,7 @@

 package com.nvidia.spark.rapids.shims

-import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuParquetScanBase, RapidsConf, RapidsMeta, ScanMeta}
+import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuParquetScan, RapidsConf, RapidsMeta, ScanMeta}

 import org.apache.spark.sql.connector.read.{Scan, SupportsRuntimeFiltering}
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -29,7 +29,7 @@ class RapidsParquetScanMeta(
   extends ScanMeta[ParquetScan](pScan, conf, parent, rule) {

   override def tagSelfForGpu(): Unit = {
-    GpuParquetScanBase.tagSupport(this)
+    GpuParquetScan.tagSupport(this)
     // we are being overly cautious and that Parquet does not support this yet
     if (pScan.isInstanceOf[SupportsRuntimeFiltering]) {
       willNotWorkOnGpu("Parquet does not support Runtime filtering (DPP)" +
@@ -51,35 +51,41 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.QueryExecutionException
-import org.apache.spark.sql.execution.datasources.PartitionedFile
+import org.apache.spark.sql.execution.datasources.{PartitionedFile, PartitioningAwareFileIndex}
 import org.apache.spark.sql.execution.datasources.orc.OrcUtils
 import org.apache.spark.sql.execution.datasources.rapids.OrcFiltersWrapper
-import org.apache.spark.sql.execution.datasources.v2.{EmptyPartitionReader, FilePartitionReaderFactory}
+import org.apache.spark.sql.execution.datasources.v2.{EmptyPartitionReader, FilePartitionReaderFactory, FileScan}
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.rapids.execution.TrampolineUtil
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType, MapType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.SerializableConfiguration

-abstract class GpuOrcScanBase(
+case class GpuOrcScan(
     sparkSession: SparkSession,
     hadoopConf: Configuration,
+    fileIndex: PartitioningAwareFileIndex,
     dataSchema: StructType,
     readDataSchema: StructType,
     readPartitionSchema: StructType,
+    options: CaseInsensitiveStringMap,
     pushedFilters: Array[Filter],
+    partitionFilters: Seq[Expression],
+    dataFilters: Seq[Expression],
     rapidsConf: RapidsConf,
-    queryUsesInputFile: Boolean)
-  extends ScanWithMetrics with Logging {
+    queryUsesInputFile: Boolean = true)
+  extends ScanWithMetrics with FileScan with Logging {

-  def isSplitableBase(path: Path): Boolean = true
+  override def isSplitable(path: Path): Boolean = true

-  def createReaderFactoryBase(): PartitionReaderFactory = {
+  override def createReaderFactory(): PartitionReaderFactory = {
     // Unset any serialized search argument setup by Spark's OrcScanBuilder as
     // it will be incompatible due to shading and potential ORC classifier mismatch.
     hadoopConf.unset(OrcConf.KRYO_SARG.getAttribute)
@@ -96,9 +102,27 @@ abstract class GpuOrcScanBase(
       queryUsesInputFile)
   }
-}
+
+  override def equals(obj: Any): Boolean = obj match {
+    case o: GpuOrcScan =>
+      super.equals(o) && dataSchema == o.dataSchema && options == o.options &&
+        equivalentFilters(pushedFilters, o.pushedFilters) && rapidsConf == o.rapidsConf &&
+        queryUsesInputFile == o.queryUsesInputFile
+    case _ => false
+  }
+
+  override def hashCode(): Int = getClass.hashCode()
+
+  override def description(): String = {
+    super.description() + ", PushedFilters: " + seqToString(pushedFilters)
+  }
+
+  def withFilters(
+      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
+}

-object GpuOrcScanBase {
+object GpuOrcScan {
   def tagSupport(scanMeta: ScanMeta[OrcScan]): Unit = {
     val scan = scanMeta.wrapped
     val schema = StructType(scan.readDataSchema ++ scan.readPartitionSchema)
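
Two details of the new case class are worth noting. First, equals is structural while hashCode is just getClass.hashCode(); that still satisfies the equals/hashCode contract (all equal scans share a class, hence a hash) and follows the convention of Spark's built-in FileScan implementations, at the cost of hashing every instance to one bucket. Second, because GpuOrcScan is now a case class, withFilters, which Spark uses to thread partition and data filters into the scan during pruning, is a one-line copy. A hedged, self-contained sketch of that copy-based pattern (stand-in types, not Spark's real FileScan contract):

    // Simplified model of FileScan.withFilters implemented via case-class copy.
    object WithFiltersSketch {
      type Expression = String // stand-in for Catalyst's Expression

      trait FileScanLike {
        def withFilters(partitionFilters: Seq[Expression],
            dataFilters: Seq[Expression]): FileScanLike
      }

      case class SketchScan(
          partitionFilters: Seq[Expression] = Nil,
          dataFilters: Seq[Expression] = Nil) extends FileScanLike {
        // The compiler-generated copy threads the new filters through
        // while every other field keeps its current value.
        override def withFilters(partitionFilters: Seq[Expression],
            dataFilters: Seq[Expression]): FileScanLike =
          copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
      }
    }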
(Diffs for the remaining changed files were not loaded in this capture.)
