Handle readBatch changes for Spark 3.3.0 #5425
@@ -0,0 +1,161 @@
/*
 * Copyright (c) 2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.datasources.parquet

import java.io.IOException
import java.lang.reflect.Method

import com.nvidia.spark.CurrentBatchIterator
import com.nvidia.spark.rapids.ParquetCachedBatch
import java.util
import org.apache.hadoop.conf.Configuration
import org.apache.parquet.ParquetReadOptions
import org.apache.parquet.column.ColumnDescriptor
import org.apache.parquet.schema.Type

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.datasources.parquet.rapids.shims.ShimVectorizedColumnReader
import org.apache.spark.sql.execution.vectorized.WritableColumnVector
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
object ParquetVectorizedReader {
  private var readBatchMethod: Method = null
  def getReadBatchMethod(): Method = {
    if (readBatchMethod == null) {
      readBatchMethod =
        classOf[VectorizedColumnReader].getDeclaredMethod("readBatch", Integer.TYPE,
          classOf[WritableColumnVector])
      readBatchMethod.setAccessible(true)
    }
    readBatchMethod
  }
}

Review thread on getReadBatchMethod:
- Consider …
- What is the advantage of using …? Nor do we run on a JRE prior to 1.4, where IMO this would bring a third party into the picture when it's not really adding any value. Please feel free to elaborate on why you made the suggestion.
- Not a big deal, but there are invokeMethod flavors where you can just pass forceAccess = true.
- That will look up the method every time, as seen here. I don't think that's what we want.
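For context on the thread above, here is a minimal sketch (not part of this PR; the object and method names are hypothetical) of the Apache Commons Lang 3 MethodUtils.invokeMethod flavor that takes forceAccess = true, contrasted with the cached Method above: MethodUtils re-resolves the method on every call, which is the lookup cost the reply points out, while the object above resolves it once and reuses it.

// Hypothetical sketch only; assumes org.apache.commons.lang3 is on the classpath.
// It must live in this package because VectorizedColumnReader is package-private.
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.commons.lang3.reflect.MethodUtils

import org.apache.spark.sql.execution.vectorized.WritableColumnVector

object ReadBatchViaCommonsLang {
  def readBatch(reader: VectorizedColumnReader, total: Int, column: WritableColumnVector): Unit = {
    // forceAccess = true lets us call the package-private readBatch without
    // setAccessible, but the Method is looked up again on every invocation.
    MethodUtils.invokeMethod(reader, true, "readBatch", Int.box(total), column)
  }
}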
/**
 * This class takes a lot of the logic from
 * org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader
 */
class ShimCurrentBatchIterator(
    parquetCachedBatch: ParquetCachedBatch,
    conf: SQLConf,
    selectedAttributes: Seq[Attribute],
    options: ParquetReadOptions,
    hadoopConf: Configuration) extends CurrentBatchIterator(
    parquetCachedBatch,
    conf,
    selectedAttributes,
    options,
    hadoopConf) {

  var columnReaders: Array[VectorizedColumnReader] = _
  val missingColumns = new Array[Boolean](reqParquetSchemaInCacheOrder.getFieldCount)
  val typesInCache: util.List[Type] = reqParquetSchemaInCacheOrder.asGroupType.getFields
  val columnsInCache: util.List[ColumnDescriptor] = reqParquetSchemaInCacheOrder.getColumns
  val columnsRequested: util.List[ColumnDescriptor] = reqParquetSchemaInCacheOrder.getColumns

  // Initialize missingColumns to cover the case where a requested column isn't present in the
  // cache, which should never happen, but just in case it does.
  val paths: util.List[Array[String]] = reqParquetSchemaInCacheOrder.getPaths

  for (i <- 0 until reqParquetSchemaInCacheOrder.getFieldCount) {
    val t = reqParquetSchemaInCacheOrder.getFields.get(i)
    if (!t.isPrimitive || t.isRepetition(Type.Repetition.REPEATED)) {
      throw new UnsupportedOperationException("Complex types not supported.")
    }
    val colPath = paths.get(i)
    if (inMemCacheParquetSchema.containsPath(colPath)) {
      val fd = inMemCacheParquetSchema.getColumnDescription(colPath)
      if (!(fd == columnsRequested.get(i))) {
        throw new UnsupportedOperationException("Schema evolution not supported.")
      }
      missingColumns(i) = false
    } else {
      if (columnsRequested.get(i).getMaxDefinitionLevel == 0) {
        // Column is missing in data but the required data is non-nullable.
        // This file is invalid.
        throw new IOException(s"Required column is missing in data file: ${colPath.toList}")
      }
      missingColumns(i) = true
    }
  }
  for (i <- missingColumns.indices) {
    if (missingColumns(i)) {
      vectors(i).putNulls(0, capacity)
      vectors(i).setIsConstant()
    }
  }

Review thread on this loop:
- Seems like we could save this extra loop if we moved L100-101 to after 94.
- You are right.
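Following the thread above, a sketch of the single-pass form the reviewer is suggesting (illustrative only, not part of the diff; it reuses the fields already defined in this class): the null-fill happens as soon as a column is marked missing, so the second loop over missingColumns.indices goes away.

  // Illustrative single-pass variant of the two loops above.
  for (i <- 0 until reqParquetSchemaInCacheOrder.getFieldCount) {
    val t = reqParquetSchemaInCacheOrder.getFields.get(i)
    if (!t.isPrimitive || t.isRepetition(Type.Repetition.REPEATED)) {
      throw new UnsupportedOperationException("Complex types not supported.")
    }
    val colPath = paths.get(i)
    if (inMemCacheParquetSchema.containsPath(colPath)) {
      val fd = inMemCacheParquetSchema.getColumnDescription(colPath)
      if (!(fd == columnsRequested.get(i))) {
        throw new UnsupportedOperationException("Schema evolution not supported.")
      }
      missingColumns(i) = false
    } else {
      if (columnsRequested.get(i).getMaxDefinitionLevel == 0) {
        // Column is missing in data but the required data is non-nullable. This file is invalid.
        throw new IOException(s"Required column is missing in data file: ${colPath.toList}")
      }
      // Mark the column missing and null-fill its vector immediately, instead of a second loop.
      missingColumns(i) = true
      vectors(i).putNulls(0, capacity)
      vectors(i).setIsConstant()
    }
  }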
  @throws[IOException]
  def checkEndOfRowGroup(): Unit = {
    if (rowsReturned != totalCountLoadedSoFar) return
    val pages = parquetFileReader.readNextRowGroup
    if (pages == null) {
      throw new IOException("expecting more rows but reached last" +
          " block. Read " + rowsReturned + " out of " + totalRowCount)
    }
    columnReaders = new Array[VectorizedColumnReader](columnsRequested.size)
    for (i <- 0 until columnsRequested.size) {
      if (!missingColumns(i)) {
        columnReaders(i) =
          new ShimVectorizedColumnReader(
            i,
            columnsInCache,
            typesInCache,
            pages,
            convertTz = null,
            LegacyBehaviorPolicy.CORRECTED.toString,
            LegacyBehaviorPolicy.EXCEPTION.toString,
            int96CDPHive3Compatibility = false,
            writerVersion)
      }
    }
    totalCountLoadedSoFar += pages.getRowCount
  }

  /**
   * Read the next RowGroup, read each column, and return the ColumnarBatch.
   */
  def nextBatch: Boolean = {
    for (vector <- vectors) {
      vector.reset()
    }
    columnarBatch.setNumRows(0)
    if (rowsReturned >= totalRowCount) return false
    checkEndOfRowGroup()
    val num = Math.min(capacity.toLong, totalCountLoadedSoFar - rowsReturned).toInt
    for (i <- columnReaders.indices) {
      if (columnReaders(i) != null) {
        ParquetVectorizedReader.getReadBatchMethod()
            .invoke(columnReaders(i), num.asInstanceOf[AnyRef],
              vectors(cacheSchemaToReqSchemaMap(i)).asInstanceOf[AnyRef])
      }
    }
    rowsReturned += num
    columnarBatch.setNumRows(num)
    true
  }

  override def hasNext: Boolean = rowsReturned < totalRowCount

  TaskContext.get().addTaskCompletionListener[Unit]((_: TaskContext) => {
    close()
  })

}
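Context for the second file below: Spark 3.3.0 changed the package-private VectorizedColumnReader.readBatch from the two-argument form used above (row count plus a value vector) to a four-argument form that also takes repetition-level and definition-level vectors, which is why this PR carries a separate shim per Spark version. As a purely hypothetical alternative (not what this PR does), the signature could also be probed once at runtime:

// Hypothetical sketch: detect which readBatch signature the running Spark provides.
// Must live in this package because VectorizedColumnReader is package-private.
package org.apache.spark.sql.execution.datasources.parquet

import java.lang.reflect.Method

import org.apache.spark.sql.execution.vectorized.WritableColumnVector

object ReadBatchProbe {
  lazy val readBatchMethod: Method = {
    val clazz = classOf[VectorizedColumnReader]
    val m =
      try {
        // Spark 3.3.0+: readBatch(int, values, repetitionLevels, definitionLevels)
        clazz.getDeclaredMethod("readBatch", Integer.TYPE,
          classOf[WritableColumnVector],
          classOf[WritableColumnVector],
          classOf[WritableColumnVector])
      } catch {
        case _: NoSuchMethodException =>
          // Spark 3.2.x and earlier: readBatch(int, values)
          clazz.getDeclaredMethod("readBatch", Integer.TYPE, classOf[WritableColumnVector])
      }
    m.setAccessible(true)
    m
  }
}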
@@ -0,0 +1,205 @@
/*
 * Copyright (c) 2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.datasources.parquet

import java.io.IOException
import java.lang.reflect.Method

import scala.collection.JavaConverters._

import com.nvidia.spark.CurrentBatchIterator
import com.nvidia.spark.rapids.ParquetCachedBatch
import java.util
import org.apache.hadoop.conf.Configuration
import org.apache.parquet.ParquetReadOptions
import org.apache.parquet.column.page.PageReadStore
import org.apache.parquet.schema.{GroupType, Type}

import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.vectorized.WritableColumnVector
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.types.StructType
object ParquetVectorizedReader {
  /**
   * We are getting this method using reflection because it's package-private.
   */
  private var readBatchMethod: Method = null

  def getReadBatchMethod(): Method = {
    if (readBatchMethod == null) {
      readBatchMethod =
        classOf[VectorizedColumnReader].getDeclaredMethod("readBatch", Integer.TYPE,
          classOf[WritableColumnVector],
          classOf[WritableColumnVector],
          classOf[WritableColumnVector])
      readBatchMethod.setAccessible(true)
    }
    readBatchMethod
  }
}
/**
 * This class takes a lot of the logic from
 * org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader
 */
class ShimCurrentBatchIterator(
    parquetCachedBatch: ParquetCachedBatch,
    conf: SQLConf,
    selectedAttributes: Seq[Attribute],
    options: ParquetReadOptions,
    hadoopConf: Configuration)
    extends CurrentBatchIterator(
      parquetCachedBatch,
      conf,
      selectedAttributes,
      options,
      hadoopConf) {

  val missingColumns: util.Set[ParquetColumn] = new util.HashSet[ParquetColumn]()
  val config = new Configuration
  config.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, false)
  config.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, false)
  config.setBoolean(SQLConf.CASE_SENSITIVE.key, false)
  val parquetColumn = new ParquetToSparkSchemaConverter(config)
      .convertParquetColumn(reqParquetSchemaInCacheOrder, Option.empty)

  // Initialize missingColumns to cover the case where a requested column isn't present in the
  // cache, which should never happen, but just in case it does.
  for (column <- parquetColumn.children) {
    checkColumn(column)
  }

  val sparkSchema = parquetColumn.sparkType.asInstanceOf[StructType]
  val parquetColumnVectors = (for (i <- 0 until sparkSchema.fields.length) yield {
    new ParquetColumnVector(parquetColumn.children.apply(i),
      vectors(i), capacity, MemoryMode.OFF_HEAP, missingColumns)
  }).toArray
  private def containsPath(parquetType: Type, path: Array[String]): Boolean =
    containsPath(parquetType, path, 0)

  private def containsPath(parquetType: Type, path: Array[String], depth: Int): Boolean = {
    if (path.length == depth) return true
    if (parquetType.isInstanceOf[GroupType]) {
      val fieldName = path(depth)
      val parquetGroupType = parquetType.asInstanceOf[GroupType]
      if (parquetGroupType.containsField(fieldName)) {
        return containsPath(parquetGroupType.getType(fieldName), path, depth + 1)
      }
    }
    false
  }

  private def checkColumn(column: ParquetColumn): Unit = {
    val paths = column.path.toArray
    if (containsPath(inMemCacheParquetSchema, paths)) {
      if (column.isPrimitive) {
        val desc = column.descriptor.get
        val fd = inMemCacheParquetSchema.getColumnDescription(desc.getPath)
        if (!fd.equals(desc)) {
          throw new UnsupportedOperationException("Schema evolution not supported.")
        }
      } else {
        for (childColumn <- column.children) {
          checkColumn(childColumn)
        }
      }
    } else {
      if (column.required) {
        // Column is missing in data but the required data is non-nullable. This file is invalid.
        throw new IOException("Required column is missing in data file. Col: " + paths.toList)
      }
      missingColumns.add(column)
    }
  }
  @throws[IOException]
  private def initColumnReader(pages: PageReadStore, cv: ParquetColumnVector): Unit = {
    if (!missingColumns.contains(cv.getColumn)) {
      if (cv.getColumn.isPrimitive) {
        val column = cv.getColumn
        val reader = new VectorizedColumnReader(
          column.descriptor.get,
          column.required,
          pages,
          null,
          LegacyBehaviorPolicy.CORRECTED.toString,
          LegacyBehaviorPolicy.EXCEPTION.toString,
          LegacyBehaviorPolicy.EXCEPTION.toString,
          null,
          writerVersion)
        cv.setColumnReader(reader)
      } else {
        // Not in missing columns and is a complex type: this must be a struct
        for (childCv <- cv.getChildren.asScala) {
          initColumnReader(pages, childCv)
        }
      }
    }
  }

  @throws[IOException]
  def checkEndOfRowGroup(): Unit = {
    if (rowsReturned != totalCountLoadedSoFar) return
    val pages = parquetFileReader.readNextRowGroup
    if (pages == null) {
      throw new IOException("expecting more rows but reached last" +
          " block. Read " + rowsReturned + " out of " + totalRowCount)
    }
    for (cv <- parquetColumnVectors) {
      initColumnReader(pages, cv)
    }
    totalCountLoadedSoFar += pages.getRowCount
  }

  /**
   * Read the next RowGroup, read each column, and return the ColumnarBatch.
   */
  def nextBatch: Boolean = {
    for (vector <- parquetColumnVectors) {
      vector.reset()
    }
    columnarBatch.setNumRows(0)
    if (rowsReturned >= totalRowCount) return false
    checkEndOfRowGroup()

    val num = Math.min(capacity.toLong, totalCountLoadedSoFar - rowsReturned).toInt
    for (cv <- parquetColumnVectors) {
      for (leafCv <- cv.getLeaves.asScala) {
        val columnReader = leafCv.getColumnReader
        if (columnReader != null) {
          ParquetVectorizedReader.getReadBatchMethod.invoke(
            columnReader,
            num.asInstanceOf[AnyRef],
            leafCv.getValueVector.asInstanceOf[AnyRef],
            leafCv.getRepetitionLevelVector.asInstanceOf[AnyRef],
            leafCv.getDefinitionLevelVector.asInstanceOf[AnyRef])
        }
      }
      cv.assemble()
    }
    rowsReturned += num
    columnarBatch.setNumRows(num)
    true
  }
}
nartal1 marked this conversation as resolved.

Review comment: consider lazy val instead of manual memoization.
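A minimal sketch of that suggestion, assuming the Spark 3.3.0 four-argument signature used in this file; a lazy val gives once-only, thread-safe initialization without the explicit null check (the object name below is hypothetical):

// Hypothetical sketch of the lazy val form; same package as the shim above.
package org.apache.spark.sql.execution.datasources.parquet

import java.lang.reflect.Method

import org.apache.spark.sql.execution.vectorized.WritableColumnVector

object LazyReadBatchMethod {
  lazy val readBatchMethod: Method = {
    val m = classOf[VectorizedColumnReader].getDeclaredMethod("readBatch", Integer.TYPE,
      classOf[WritableColumnVector],
      classOf[WritableColumnVector],
      classOf[WritableColumnVector])
    m.setAccessible(true)
    m
  }
}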