Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix infinite loop in MultiFileCloudPartitionReaderBase #2873

Merged
merged 3 commits into from
Jul 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,8 @@ abstract class MultiFileCloudPartitionReaderBase(
private var isInitted = false
private val tasks = new ConcurrentLinkedQueue[Future[HostMemoryBuffersWithMetaDataBase]]()
private val tasksToRun = new Queue[Callable[HostMemoryBuffersWithMetaDataBase]]()
private[this] val inputMetrics = TaskContext.get.taskMetrics().inputMetrics
private[this] val inputMetrics = Option(TaskContext.get).map(_.taskMetrics().inputMetrics)
.getOrElse(TrampolineUtil.newInputMetrics())

private def initAndStartReaders(): Unit = {
// limit the number we submit at once according to the config if set
Expand Down Expand Up @@ -382,11 +383,12 @@ abstract class MultiFileCloudPartitionReaderBase(
// if we have batch left from the last file read return it
if (currentFileHostBuffers.isDefined) {
if (getSizeOfHostBuffers(currentFileHostBuffers.get) == 0) {
closeCurrentFileHostBuffers()
jlowe marked this conversation as resolved.
Show resolved Hide resolved
next()
} else {
batch = readBatch(currentFileHostBuffers.get)
}
batch = readBatch(currentFileHostBuffers.get)
} else {
currentFileHostBuffers = None
if (filesToRead > 0 && !isDone) {
val fileBufsAndMeta = tasks.poll.get()
filesToRead -= 1
Expand All @@ -399,6 +401,7 @@ abstract class MultiFileCloudPartitionReaderBase(
if (getSizeOfHostBuffers(fileBufsAndMeta) == 0) {
// if sizes are 0 means no rows and no data so skip to next file
// file data was empty so submit another task if any were waiting
closeCurrentFileHostBuffers()
addNextTaskIfNeeded()
next()
} else {
Expand Down Expand Up @@ -436,10 +439,7 @@ abstract class MultiFileCloudPartitionReaderBase(
}
}

override def close(): Unit = {
// this is more complicated because threads might still be processing files
// in cases close got called early for like limit() calls
isDone = true
private def closeCurrentFileHostBuffers(): Unit = {
currentFileHostBuffers.foreach { current =>
current.memBuffersAndSizes.foreach { case (buf, _) =>
if (buf != null) {
Expand All @@ -448,6 +448,13 @@ abstract class MultiFileCloudPartitionReaderBase(
}
}
currentFileHostBuffers = None
}

override def close(): Unit = {
// Closing is more complicated here because worker threads might still be
// processing files, e.g. when close() is called early due to a limit() call.
isDone = true
closeCurrentFileHostBuffers()
batch.foreach(_.close())
batch = None
tasks.asScala.foreach { task =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ object TrampolineUtil {

def asNullable(dt: DataType): DataType = dt.asNullable

/** Return a new InputMetrics instance, used as a fallback when no Spark TaskContext is available. */
def newInputMetrics(): InputMetrics = new InputMetrics()

/**
* Increment the task's memory bytes spilled metric. If the current thread does not
* correspond to a Spark task then this call does nothing.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import java.util.concurrent.{Callable, ThreadPoolExecutor}

import ai.rapids.cudf.HostMemoryBuffer
import org.apache.hadoop.conf.Configuration
import org.scalatest.FunSuite

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
 * Regression test for MultiFileCloudPartitionReaderBase: next() previously
 * recursed forever when the current file's host buffers were all empty
 * (zero-sized). With the fix, next() must release the empty buffers and
 * return false instead of looping.
 */
class GpuMultiFileReaderSuite extends FunSuite with Arm {
  test("avoid infinite loop when host buffers empty") {
    val conf = new Configuration(false)
    // A single zero-length host buffer — the "no rows, no data" case that
    // used to trigger the infinite loop.
    val membuffers = Array((HostMemoryBuffer.allocate(0), 0L))
    val multiFileReader = new MultiFileCloudPartitionReaderBase(
      conf,
      files = Array.empty,
      numThreads = 1,
      maxNumFileProcessed = 1,
      filters = Array.empty,
      execMetrics = Map(GpuMetric.PEAK_DEVICE_MEMORY -> NoopMetric)) {

      // Setup some empty host buffers at the start, so the first call to
      // next() goes straight down the empty-buffer code path.
      currentFileHostBuffers = Some(new HostMemoryBuffersWithMetaDataBase {
        override def partitionedFile: PartitionedFile =
          PartitionedFile(InternalRow.empty, "", 0, 0)
        override def memBuffersAndSizes: Array[(HostMemoryBuffer, Long)] = membuffers
        override def bytesRead: Long = 0
      })

      // Not expected to be invoked here since `files` is empty; returns a
      // task that yields null.
      override def getBatchRunner(
          file: PartitionedFile,
          conf: Configuration,
          filters: Array[Filter]): Callable[HostMemoryBuffersWithMetaDataBase] = {
        () => null
      }

      override def getThreadPool(numThreads: Int): ThreadPoolExecutor =
        MultiFileThreadPoolUtil.createThreadPool("testpool")

      // No real decode is needed for this test; never produce a batch.
      override def readBatch(h: HostMemoryBuffersWithMetaDataBase): Option[ColumnarBatch] = None

      override def getFileFormatShortName: String = ""
    }

    // With the fix, next() terminates and reports no more batches.
    withResource(multiFileReader) { _ =>
      assertResult(false)(multiFileReader.next())
    }
  }
}