Handle readBatch changes for Spark 3.3.0 #5425

Merged: 9 commits merged on May 9, 2022
Changes from 4 commits
@@ -0,0 +1,161 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.parquet

import java.io.IOException
import java.lang.reflect.Method

import com.nvidia.spark.CurrentBatchIterator
import com.nvidia.spark.rapids.ParquetCachedBatch
import java.util
import org.apache.hadoop.conf.Configuration
import org.apache.parquet.ParquetReadOptions
import org.apache.parquet.column.ColumnDescriptor
import org.apache.parquet.schema.Type

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.datasources.parquet.rapids.shims.ShimVectorizedColumnReader
import org.apache.spark.sql.execution.vectorized.WritableColumnVector
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy

object ParquetVectorizedReader {
private var readBatchMethod: Method = null
def getReadBatchMethod(): Method = {
if (readBatchMethod == null) {
Collaborator: consider lazy val instead of manual memoization (a sketch follows this object below)

readBatchMethod =
classOf[VectorizedColumnReader].getDeclaredMethod("readBatch", Integer.TYPE,
classOf[WritableColumnVector])
readBatchMethod.setAccessible(true)
Collaborator: Consider org.apache.commons.lang3.reflect.MethodUtils for such tasks.

@razajafri (collaborator, author), May 5, 2022: What is the advantage of using MethodUtils? It doesn't reduce the number of lines of code:

val method = MethodUtils.getMatchingMethod(classOf[VectorizedColumnReader], "readBatch", Integer.TYPE, classOf[WritableColumnVector])
method.setAccessible(true)

Nor do we run on a JRE prior to 1.4, where MethodUtils has a workaround. IMO this would bring a third party into the picture when it's not really adding any value. Please feel free to elaborate on why you made the suggestion.

Collaborator: not a big deal, but there are invokeMethod flavors where you can just pass forceAccess = true.

Collaborator (author): That will look up the method every time, as seen here. I don't think that's what we want.
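For illustration only, a minimal sketch of the invokeMethod flavor mentioned above (assumes commons-lang3 3.5 or later; columnReader, num and vector are placeholders standing in for the values used in nextBatch below). It avoids the explicit setAccessible call, but it resolves "readBatch" by name on every invocation, which is the repeated lookup discussed here:

import org.apache.commons.lang3.reflect.MethodUtils

// forceAccess = true lets us call the package-private readBatch,
// but the method is matched by name and argument types on each call
MethodUtils.invokeMethod(columnReader, true, "readBatch",
  num.asInstanceOf[AnyRef], vector.asInstanceOf[AnyRef])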

}
readBatchMethod
}
}
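Referring back to the lazy val suggestion above, a minimal sketch of what the memoized lookup could look like (illustrative only, not the code merged in this PR):

object ParquetVectorizedReader {
  // lazy val gives thread-safe, initialize-once semantics without the explicit null check
  private lazy val readBatchMethod: Method = {
    val m = classOf[VectorizedColumnReader].getDeclaredMethod("readBatch", Integer.TYPE,
      classOf[WritableColumnVector])
    m.setAccessible(true)
    m
  }
}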

/**
* This class takes a lot of the logic from
* org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader
*/
class ShimCurrentBatchIterator(
parquetCachedBatch: ParquetCachedBatch,
conf: SQLConf,
selectedAttributes: Seq[Attribute],
options: ParquetReadOptions,
hadoopConf: Configuration) extends CurrentBatchIterator(
parquetCachedBatch,
conf,
selectedAttributes,
options,
hadoopConf) {

var columnReaders: Array[VectorizedColumnReader] = _
val missingColumns = new Array[Boolean](reqParquetSchemaInCacheOrder.getFieldCount)
val typesInCache: util.List[Type] = reqParquetSchemaInCacheOrder.asGroupType.getFields
val columnsInCache: util.List[ColumnDescriptor] = reqParquetSchemaInCacheOrder.getColumns
val columnsRequested: util.List[ColumnDescriptor] = reqParquetSchemaInCacheOrder.getColumns

// initialize missingColumns to cover the case where a requested column isn't present in the
// cache; this should never happen, but handle it just in case
val paths: util.List[Array[String]] = reqParquetSchemaInCacheOrder.getPaths

for (i <- 0 until reqParquetSchemaInCacheOrder.getFieldCount) {
val t = reqParquetSchemaInCacheOrder.getFields.get(i)
if (!t.isPrimitive || t.isRepetition(Type.Repetition.REPEATED)) {
throw new UnsupportedOperationException("Complex types not supported.")
}
val colPath = paths.get(i)
if (inMemCacheParquetSchema.containsPath(colPath)) {
val fd = inMemCacheParquetSchema.getColumnDescription(colPath)
if (!(fd == columnsRequested.get(i))) {
throw new UnsupportedOperationException("Schema evolution not supported.")
}
missingColumns(i) = false
} else {
if (columnsRequested.get(i).getMaxDefinitionLevel == 0) {
// Column is missing in data but the required data is non-nullable.
// This file is invalid.
throw new IOException(s"Required column is missing in data file: ${colPath.toList}")
}
missingColumns(i) = true
}
}

for (i <- missingColumns.indices) {
Collaborator: seems like we could save this extra loop if we moved L100-101 to after 94 missingColumns(i) = true (a sketch follows this loop below)

Collaborator (author): You are right

if (missingColumns(i)) {
vectors(i).putNulls(0, capacity)
vectors(i).setIsConstant()
}
}
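Per the reviewer's suggestion above, a minimal sketch of how the null-filling could fold into the else branch of the first loop, removing this second pass over missingColumns (illustrative only, not the code as reviewed in this commit):

} else {
  if (columnsRequested.get(i).getMaxDefinitionLevel == 0) {
    // Column is missing in data but the required data is non-nullable. This file is invalid.
    throw new IOException(s"Required column is missing in data file: ${colPath.toList}")
  }
  missingColumns(i) = true
  // fill the missing column with nulls once, right where it is detected
  vectors(i).putNulls(0, capacity)
  vectors(i).setIsConstant()
}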

@throws[IOException]
def checkEndOfRowGroup(): Unit = {
if (rowsReturned != totalCountLoadedSoFar) return
val pages = parquetFileReader.readNextRowGroup
if (pages == null) {
throw new IOException("expecting more rows but reached last" +
" block. Read " + rowsReturned + " out of " + totalRowCount)
}
columnReaders = new Array[VectorizedColumnReader](columnsRequested.size)
for (i <- 0 until columnsRequested.size) {
if (!missingColumns(i)) {
columnReaders(i) =
new ShimVectorizedColumnReader(
i,
columnsInCache,
typesInCache,
pages,
convertTz = null,
LegacyBehaviorPolicy.CORRECTED.toString,
LegacyBehaviorPolicy.EXCEPTION.toString,
int96CDPHive3Compatibility = false,
writerVersion)
}
}
totalCountLoadedSoFar += pages.getRowCount
}

/**
* Read the next RowGroup and each requested column into columnarBatch; returns true if a batch was read
*/
def nextBatch: Boolean = {
for (vector <- vectors) {
vector.reset()
}
columnarBatch.setNumRows(0)
if (rowsReturned >= totalRowCount) return false
checkEndOfRowGroup()
val num = Math.min(capacity.toLong, totalCountLoadedSoFar - rowsReturned).toInt
for (i <- columnReaders.indices) {
if (columnReaders(i) != null) {
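// reflectively invoke the package-private readBatch(total, column) on the pre-3.3.0 reader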
ParquetVectorizedReader.getReadBatchMethod()
.invoke(columnReaders(i), num.asInstanceOf[AnyRef],
vectors(cacheSchemaToReqSchemaMap(i)).asInstanceOf[AnyRef])
}
}
rowsReturned += num
columnarBatch.setNumRows(num)
true
}

override def hasNext: Boolean = rowsReturned < totalRowCount

TaskContext.get().addTaskCompletionListener[Unit]((_: TaskContext) => {
close()
})

}
@@ -0,0 +1,205 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.parquet

import java.io.IOException
import java.lang.reflect.Method

import scala.collection.JavaConverters._

import com.nvidia.spark.CurrentBatchIterator
import com.nvidia.spark.rapids.ParquetCachedBatch
import java.util
import org.apache.hadoop.conf.Configuration
import org.apache.parquet.ParquetReadOptions
import org.apache.parquet.column.page.PageReadStore
import org.apache.parquet.schema.{GroupType, Type}

import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.vectorized.WritableColumnVector
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.types.StructType

object ParquetVectorizedReader {
/**
* We get this method via reflection because it is package-private
*/
private var readBatchMethod: Method = null

def getReadBatchMethod(): Method = {
if (readBatchMethod == null) {
readBatchMethod =
classOf[VectorizedColumnReader].getDeclaredMethod("readBatch", Integer.TYPE,
classOf[WritableColumnVector],
classOf[WritableColumnVector],
classOf[WritableColumnVector])
readBatchMethod.setAccessible(true)
}
readBatchMethod
}
}
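For context on why this shim differs from the pre-3.3.0 one: earlier Spark versions expose readBatch(total, valueVector), while Spark 3.3.0 adds repetition- and definition-level vectors to the signature, which is what the four-argument lookup above reflects. A minimal side-by-side sketch of the two reflective calls (reader and vector names are placeholders):

// pre-3.3.0 shim: readBatch(total, valueVector)
readBatchMethod.invoke(reader, num.asInstanceOf[AnyRef], valueVector.asInstanceOf[AnyRef])

// Spark 3.3.0 shim: readBatch(total, valueVector, repetitionLevels, definitionLevels)
readBatchMethod.invoke(reader, num.asInstanceOf[AnyRef], valueVector.asInstanceOf[AnyRef],
  repLevelVector.asInstanceOf[AnyRef], defLevelVector.asInstanceOf[AnyRef])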

/**
* This class takes a lot of the logic from
* org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader
*/
class ShimCurrentBatchIterator(
parquetCachedBatch: ParquetCachedBatch,
conf: SQLConf,
selectedAttributes: Seq[Attribute],
options: ParquetReadOptions,
hadoopConf: Configuration)
extends CurrentBatchIterator(
parquetCachedBatch,
conf,
selectedAttributes,
options,
hadoopConf) {

val missingColumns: util.Set[ParquetColumn] = new util.HashSet[ParquetColumn]()
val config = new Configuration
config.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, false)
config.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, false)
config.setBoolean(SQLConf.CASE_SENSITIVE.key, false)
val parquetColumn = new ParquetToSparkSchemaConverter(config)
.convertParquetColumn(reqParquetSchemaInCacheOrder, Option.empty)

// initialize missingColumns to cover the case where a requested column isn't present in the
// cache; this should never happen, but handle it just in case
for (column <- parquetColumn.children) {
checkColumn(column)
}

val sparkSchema = parquetColumn.sparkType.asInstanceOf[StructType]
val parquetColumnVectors = (for (i <- 0 until sparkSchema.fields.length) yield {
new ParquetColumnVector(parquetColumn.children.apply(i),
vectors(i), capacity, MemoryMode.OFF_HEAP, missingColumns)
}).toArray

private def containsPath(parquetType: Type, path: Array[String]): Boolean =
containsPath(parquetType, path, 0)

private def containsPath(parquetType: Type, path: Array[String], depth: Int): Boolean = {
if (path.length == depth) return true
if (parquetType.isInstanceOf[GroupType]) {
val fieldName = path(depth)
val parquetGroupType = parquetType.asInstanceOf[GroupType]
if (parquetGroupType.containsField(fieldName)) {
return containsPath(parquetGroupType.getType(fieldName), path, depth + 1)
}
}
false
}

private def checkColumn(column: ParquetColumn): Unit = {
val paths = column.path.toArray
if (containsPath(inMemCacheParquetSchema, paths)) {
if (column.isPrimitive) {
val desc = column.descriptor.get
val fd = inMemCacheParquetSchema.getColumnDescription(desc.getPath)
if (!fd.equals(desc)) {
throw new UnsupportedOperationException("Complex types not supported.")
}
} else {
for (childColumn <- column.children) {
checkColumn(childColumn)
}
}
} else {
if (column.required) {
// Column is missing in data but the required data is non-nullable. This file is invalid.
throw new IOException("Required column is missing in data file. Col: " + paths.toList)
}
missingColumns.add(column)
}

}

@throws[IOException]
private def initColumnReader(pages: PageReadStore, cv: ParquetColumnVector): Unit = {
if (!missingColumns.contains(cv.getColumn)) {
if (cv.getColumn.isPrimitive) {
val column = cv.getColumn
val reader = new VectorizedColumnReader(
column.descriptor.get,
column.required,
pages,
null,
LegacyBehaviorPolicy.CORRECTED.toString,
LegacyBehaviorPolicy.EXCEPTION.toString,
LegacyBehaviorPolicy.EXCEPTION.toString,
null,
writerVersion)
cv.setColumnReader(reader)
}
else { // Not in missing columns and is a complex type: this must be a struct
for (childCv <- cv.getChildren.asScala) {
initColumnReader(pages, childCv)
}
}
}
}

@throws[IOException]
def checkEndOfRowGroup(): Unit = {
if (rowsReturned != totalCountLoadedSoFar) return
val pages = parquetFileReader.readNextRowGroup
if (pages == null) {
throw new IOException("expecting more rows but reached last" +
" block. Read " + rowsReturned + " out of " + totalRowCount)
}
for (cv <- parquetColumnVectors) {
initColumnReader(pages, cv)
}
totalCountLoadedSoFar += pages.getRowCount
}

/**
* Read the next RowGroup and each requested column into columnarBatch; returns true if a batch was read
*/
def nextBatch: Boolean = {
for (vector <- parquetColumnVectors) {
vector.reset()
}
columnarBatch.setNumRows(0)
if (rowsReturned >= totalRowCount) return false
checkEndOfRowGroup()

val num = Math.min(capacity.toLong, totalCountLoadedSoFar - rowsReturned).toInt
for (cv <- parquetColumnVectors) {
for (leafCv <- cv.getLeaves.asScala) {
val columnReader = leafCv.getColumnReader
if (columnReader != null) {
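// reflectively invoke the Spark 3.3.0 readBatch(total, values, repLevels, defLevels) on this leaf column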
ParquetVectorizedReader.getReadBatchMethod.invoke(
columnReader,
num.asInstanceOf[AnyRef],
leafCv.getValueVector.asInstanceOf[AnyRef],
leafCv.getRepetitionLevelVector.asInstanceOf[AnyRef],
leafCv.getDefinitionLevelVector.asInstanceOf[AnyRef])
}
}
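// assemble nested (struct/array/map) values for this column from its children; a no-op for plain primitive columns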
cv.assemble()
}
rowsReturned += num
columnarBatch.setNumRows(num)
true
}
}