Skip to content

Commit

Permalink
Fix some resource leaks in GpuCast and RapidsShuffleServerSuite (NVID…
Browse files Browse the repository at this point in the history
…IA#3231)

Signed-off-by: Jason Lowe <jlowe@nvidia.com>
Signed-off-by: Raza Jafri <rjafri@nvidia.com>
  • Loading branch information
jlowe authored and razajafri committed Aug 23, 2021
1 parent a1140b0 commit 03aab1f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 37 deletions.
26 changes: 13 additions & 13 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -243,21 +243,21 @@ object GpuCast extends Arm {
}
}

if (ansiEnabled) {
// ansi mode only supports simple integers, so no exponents or decimal places
val regex = "^[+\\-]?[0-9]+$"
withResource(sanitized.matchesRe(regex)) { isInt =>
withResource(isInt.all()) { allInts =>
// Check that all non-null values are valid integers.
if (allInts.isValid && !allInts.getBoolean) {
throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE)
withResource(sanitized) { _ =>
if (ansiEnabled) {
// ansi mode only supports simple integers, so no exponents or decimal places
val regex = "^[+\\-]?[0-9]+$"
withResource(sanitized.matchesRe(regex)) { isInt =>
withResource(isInt.all()) { allInts =>
// Check that all non-null values are valid integers.
if (allInts.isValid && !allInts.getBoolean) {
throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE)
}
}
sanitized.incRefCount()
}
sanitized
}
} else {
// truncate strings that represent decimals, so that we just look at the string before the dot
withResource(sanitized) { _ =>
} else {
// truncate strings that represent decimals to just look at the string before the dot
withResource(Scalar.fromString(".")) { dot =>
withResource(sanitized.stringContains(dot)) { hasDot =>
// only do the decimal sanitization if any strings do contain dot
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,17 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper with Arm {
val (handler, mockBuffers, numCloses) = setupMocks(deviceBuffers)
withResource(new BufferSendState(mockTx, bounceBuffer, handler)) { bss =>
assert(bss.hasMoreSends)
val mb = bss.getBufferToSend()
val receiveBlocks = receiveWindow.next()
compareRanges(bounceBuffer, receiveBlocks)
assertResult(10000)(mb.getLength)
assert(!bss.hasMoreSends)
bss.releaseAcquiredToCatalog()
mockBuffers.foreach { b: RapidsBuffer =>
// should have seen 2 closes, one for BufferSendState acquiring for metadata
// and the second acquisition for copying
verify(b, times(numCloses.get(b))).close()
withResource(bss.getBufferToSend()) { mb =>
val receiveBlocks = receiveWindow.next()
compareRanges(bounceBuffer, receiveBlocks)
assertResult(10000)(mb.getLength)
assert(!bss.hasMoreSends)
bss.releaseAcquiredToCatalog()
mockBuffers.foreach { b: RapidsBuffer =>
// should have seen 2 closes, one for BufferSendState acquiring for metadata
// and the second acquisition for copying
verify(b, times(numCloses.get(b))).close()
}
}
}
}
Expand All @@ -139,17 +140,19 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper with Arm {
val receiveWindow = new WindowedBlockIterator[MockBlockWithSize](receiveSide, 10000)
val (handler, mockBuffers, numCloses) = setupMocks(deviceBuffers)
withResource(new BufferSendState(mockTx, bounceBuffer, handler)) { bss =>
var buffs = bss.getBufferToSend()
var receiveBlocks = receiveWindow.next()
compareRanges(bounceBuffer, receiveBlocks)
assert(bss.hasMoreSends)
bss.releaseAcquiredToCatalog()
withResource(bss.getBufferToSend()) { _ =>
val receiveBlocks = receiveWindow.next()
compareRanges(bounceBuffer, receiveBlocks)
assert(bss.hasMoreSends)
bss.releaseAcquiredToCatalog()
}

buffs = bss.getBufferToSend()
receiveBlocks = receiveWindow.next()
compareRanges(bounceBuffer, receiveBlocks)
assert(!bss.hasMoreSends)
bss.releaseAcquiredToCatalog()
withResource(bss.getBufferToSend()) { _ =>
val receiveBlocks = receiveWindow.next()
compareRanges(bounceBuffer, receiveBlocks)
assert(!bss.hasMoreSends)
bss.releaseAcquiredToCatalog()
}

mockBuffers.foreach { b: RapidsBuffer =>
// should have seen 2 closes, one for BufferSendState acquiring for metadata
Expand Down Expand Up @@ -177,10 +180,11 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper with Arm {
val receiveWindow = new WindowedBlockIterator[MockBlockWithSize](receiveSide, 10000)
withResource(new BufferSendState(mockTx, bounceBuffer, handler)) { bss =>
(0 until 246).foreach { _ =>
bss.getBufferToSend()
val receiveBlocks = receiveWindow.next()
compareRanges(bounceBuffer, receiveBlocks)
bss.releaseAcquiredToCatalog()
withResource(bss.getBufferToSend()) { _ =>
val receiveBlocks = receiveWindow.next()
compareRanges(bounceBuffer, receiveBlocks)
bss.releaseAcquiredToCatalog()
}
}
assert(!bss.hasMoreSends)
}
Expand Down

0 comments on commit 03aab1f

Please sign in to comment.