Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flatten simple 4+ nesting of withResource #6833

Merged
merged 31 commits into from
Oct 25, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
87fc915
wip
gerashegalov Oct 17, 2022
8d14610
Flatten simple 4+ nesting of withResource
gerashegalov Oct 18, 2022
70d049a
andThen
gerashegalov Oct 18, 2022
68e1427
parquet
gerashegalov Oct 18, 2022
15b32a9
wip
gerashegalov Oct 18, 2022
b3facf3
wip
gerashegalov Oct 18, 2022
c60d56a
wip
gerashegalov Oct 18, 2022
5edaa88
scalastyle and spark330
gerashegalov Oct 18, 2022
ce30bc3
unleak
gerashegalov Oct 18, 2022
b936c41
review
gerashegalov Oct 19, 2022
5c28444
Fix reduceLeft
gerashegalov Oct 19, 2022
e3b56b3
revert stringFunctions
gerashegalov Oct 19, 2022
ee3ed4c
Merge branch 'branch-22.12' into reduceWithResourceScope
gerashegalov Oct 19, 2022
b229af7
reformat for review
gerashegalov Oct 19, 2022
0c26da0
more reviews
gerashegalov Oct 19, 2022
aaad636
more formatting
gerashegalov Oct 19, 2022
27a9580
scalastyle
gerashegalov Oct 19, 2022
6f04fd0
Parquet condition flipped
gerashegalov Oct 20, 2022
fdda831
Merge remote-tracking branch 'origin/branch-22.12' into reduceWithRes…
gerashegalov Oct 20, 2022
bc7424d
restore RapidsErrorUtis
gerashegalov Oct 20, 2022
9d95272
Merge remote-tracking branch 'origin/branch-22.12' into reduceWithRes…
gerashegalov Oct 21, 2022
9164def
WIP on detangling withresource
gerashegalov Oct 21, 2022
7506ce5
more unnesting
gerashegalov Oct 21, 2022
537751a
unnest more
gerashegalov Oct 21, 2022
d12b773
wip
gerashegalov Oct 22, 2022
cad5763
regex
gerashegalov Oct 22, 2022
9b32ba3
review
gerashegalov Oct 24, 2022
00127ad
revews
gerashegalov Oct 25, 2022
fbdf705
Merge remote-tracking branch 'gerashegalov/reduceWithResourceScope' i…
gerashegalov Oct 25, 2022
7056c94
fix double close
gerashegalov Oct 25, 2022
050e3ad
another leak
gerashegalov Oct 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -292,17 +292,16 @@ object GpuIntervalUtils extends Arm {
// not close firstSignInTable and secondSignInTable here, outer table.close will close them
private def finalSign(
firstSignInTable: ColumnVector, secondSignInTable: ColumnVector): ColumnVector = {
withResource(Scalar.fromString("-")) { negScalar =>
withResource(negScalar.equalTo(firstSignInTable)) { neg1 =>
withResource(negScalar.equalTo(secondSignInTable)) { neg2 =>
withResource(neg1.bitXor(neg2)) { s =>
withResource(Scalar.fromLong(1L)) { one =>
withResource(Scalar.fromLong(-1L)) { negOne =>
s.ifElse(negOne, one)
}
}
}
}
val negatives = withResource(Scalar.fromString("-")) { negScalar =>
lazyReduce(Seq(
lazyResource(negScalar.equalTo(firstSignInTable)),
lazyResource(negScalar.equalTo(secondSignInTable))
))(_ bitXor _)
gerashegalov marked this conversation as resolved.
Show resolved Hide resolved
}

withResource(negatives()) { s =>
withResource(Scalar.fromLong(1L)) { posOne =>
withResource(Scalar.fromLong(-1L))(negOne => s.ifElse(negOne, posOne))
gerashegalov marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand All @@ -315,15 +314,12 @@ object GpuIntervalUtils extends Arm {
* @return micros column
*/
private def getMicrosFromDecimal(sign: ColumnVector, decimal: ColumnVector): ColumnVector = {
withResource(decimal.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -6))) { decimal =>
withResource(Scalar.fromLong(1000000L)) { million =>
withResource(decimal.mul(million)) { r =>
withResource(r.asLongs()) { l =>
l.mul(sign)
}
}
}
}
val timesMillion = lazyReduce(Seq(
gerashegalov marked this conversation as resolved.
Show resolved Hide resolved
lazyResource(decimal.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -6))),
lazyResource(Scalar.fromLong(1000000L))))(_ mul _)

val timesMillionLongs = withResource(timesMillion().asInstanceOf[ColumnVector])(_.asLongs())
withResource(timesMillionLongs)(_ mul sign)
}

private def addFromDayToDay(
Expand Down
10 changes: 3 additions & 7 deletions sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,9 @@ object RebaseHelper extends Arm {
// https://github.com/NVIDIA/spark-rapids/issues/1126
val dtype = column.getType
if (dtype == DType.TIMESTAMP_DAYS) {
withResource(Scalar.timestampDaysFromInt(startDay)) { minGood =>
withResource(column.lessThan(minGood)) { hasBad =>
withResource(hasBad.any()) { a =>
a.isValid && a.getBoolean
}
}
}
val hasBad = withResource(Scalar.timestampDaysFromInt(startDay))(column.lessThan)
val anyBad = withResource(hasBad)(_.any())
withResource(anyBad)(_ => anyBad.isValid && anyBad.getBoolean)
} else {
false
}
Expand Down
22 changes: 22 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,28 @@ trait Arm {
h.close()
}
}


def lazyResource[T <: AutoCloseable](x: => T) = () => x

def lazyReduce[T <: AutoCloseable](
lazySeq: Seq[() => T])(
func: Function2[T, T, T]
): () => T = {
lazySeq.reduceLeft { (lazy1, lazy2) =>
lazyResource {
withResource(lazy1()) { lzv1 =>
withResource(lazy2())(lzv2 => func(lzv1, lzv2))
}
}
}
}

def lazyAndThen[T <: AutoCloseable, R <: AutoCloseable](
x: () => T
)(func: Function[T, R]): () => R = {
lazyResource(withResource(x())(func))
}
}

class CloseableHolder[T <: AutoCloseable](var t: T) {
Expand Down
27 changes: 12 additions & 15 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -340,21 +340,18 @@ class CSVPartitionReader(
* CSV supports "true" and "false" (case-insensitive) as valid boolean values.
*/
override def castStringToBool(input: ColumnVector): ColumnVector = {
withResource(input.strip()) { stripped =>
withResource(stripped.lower()) { lower =>
withResource(Scalar.fromString("true")) { t =>
withResource(Scalar.fromString("false")) { f =>
withResource(lower.equalTo(t)) { isTrue =>
withResource(lower.equalTo(f)) { isFalse =>
withResource(isTrue.or(isFalse)) { isValidBool =>
withResource(Scalar.fromNull(DType.BOOL8)) { nullBool =>
isValidBool.ifElse(isTrue, nullBool)
}
}
}
}
}
}
val lowerStripped = withResource(input.strip())(_.lower())
val isTrue = closeOnExcept(lowerStripped) { _ =>
gerashegalov marked this conversation as resolved.
Show resolved Hide resolved
withResource(Scalar.fromString(true.toString))(lowerStripped.equalTo)
}
val isValidBool = withResource(lowerStripped) { _ =>
withResource(ColumnVector.fromStrings(true.toString, false.toString)) { boolStrings =>
lowerStripped.contains(boolStrings)
}
}
withResource(isValidBool) { _ =>
withResource(Scalar.fromNull(DType.BOOL8)) { nullBool =>
isValidBool.ifElse(isTrue, nullBool)
}
}
}
Expand Down
162 changes: 62 additions & 100 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -783,20 +783,13 @@ object GpuCast extends Arm {
else {strCol.incRefCount}
}

withResource(ColumnVector.fromScalar(sep, numRows)) {sepCol =>
withResource(input.getChildColumnView(0)) { childCol =>
withResource(addSpaces(childCol)) {strChildCol =>
withResource(input.replaceListChild(strChildCol)) {strArrayCol =>
withResource(
strArrayCol.stringConcatenateListElements(sepCol)) { strColContainsNull =>
withResource(strColContainsNull.replaceNulls(empty)) {strCol =>
removeFirstSpace(strCol)
}
}
}
}
}
val strChildCol = withResource(input.getChildColumnView(0))(addSpaces)
val strArrayCol = withResource(strChildCol)(input.replaceListChild)
val strColContainsNull = withResource(ColumnVector.fromScalar(sep, numRows)) { sepCol =>
gerashegalov marked this conversation as resolved.
Show resolved Hide resolved
withResource(strArrayCol)(_.stringConcatenateListElements(sepCol))
}
val strCol = withResource(strColContainsNull)(_.replaceNulls(empty))
withResource(strCol)(removeFirstSpace)
}
}

Expand Down Expand Up @@ -831,15 +824,11 @@ object GpuCast extends Arm {
child, elementType, StringType, ansiMode, legacyCastToString, stringToDateAnsiModeEnabled)
}

withResource(strChildContainsNull) {strChildContainsNull =>
withResource(input.replaceListChild(strChildContainsNull)) {strArrayCol =>
withResource(concatenateStringArrayElements(strArrayCol, legacyCastToString)) {strCol =>
withResource(addBrackets(strCol)) {
_.mergeAndSetValidity(BinaryOp.BITWISE_AND, input)
}
}
}
}
val strArrayCol = withResource(strChildContainsNull)(input.replaceListChild)
val strCol = withResource(strArrayCol)(concatenateStringArrayElements(_, legacyCastToString))
val strColWithBrackets = withResource(strCol)(addBrackets)

withResource(strColWithBrackets)( _.mergeAndSetValidity(BinaryOp.BITWISE_AND, input))
jlowe marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down Expand Up @@ -897,20 +886,17 @@ object GpuCast extends Arm {
}

// concatenate elements
withResource(strElements) {strElements =>
withResource(input.replaceListChild(strElements)) {strArrayCol =>
withResource(concatenateStringArrayElements(strArrayCol, legacyCastToString)) {strCol =>
withResource(
Seq(leftScalar, rightScalar).safeMap(ColumnVector.fromScalar(_, numRows))
) {case Seq(leftCol, rightCol) =>
withResource(ColumnVector.stringConcatenate(
emptyScalar, nullScalar, Array(leftCol, strCol, rightCol))) {
_.mergeAndSetValidity(BinaryOp.BITWISE_AND, input)
}
}
}
}
val strArrayCol = withResource(strElements)(input.replaceListChild)
val strCol = withResource(strArrayCol)(concatenateStringArrayElements(_, legacyCastToString))
val Seq(leftCol, rightCol) = closeOnExcept(strCol) { _ =>
Seq(leftScalar, rightScalar).safeMap(ColumnVector.fromScalar(_, numRows))
}
val cols = Array[ColumnView](leftCol, strCol, rightCol)
val concatCol = withResource(cols) { _ =>
ColumnVector.stringConcatenate(emptyScalar, nullScalar, cols)
}

withResource(concatCol)( _.mergeAndSetValidity(BinaryOp.BITWISE_AND, input))
}
}

Expand Down Expand Up @@ -1006,32 +992,27 @@ object GpuCast extends Arm {
val trueStrings = Seq("t", "true", "y", "yes", "1")
val falseStrings = Seq("f", "false", "n", "no", "0")
val boolStrings = trueStrings ++ falseStrings

// determine which values are valid bool strings
withResource(ColumnVector.fromStrings(boolStrings: _*)) { boolStrings =>
withResource(input.strip()) { stripped =>
withResource(stripped.lower()) { lower =>
withResource(lower.contains(boolStrings)) { validBools =>
// in ansi mode, fail if any values are not valid bool strings
if (ansiEnabled) {
withResource(validBools.all()) { isAllBool =>
if (isAllBool.isValid && !isAllBool.getBoolean) {
throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE)
}
}
}
// replace non-boolean values with null
withResource(Scalar.fromNull(DType.STRING)) { nullString =>
withResource(validBools.ifElse(lower, nullString)) { sanitizedInput =>
// return true, false, or null, as appropriate
withResource(ColumnVector.fromStrings(trueStrings: _*)) { cvTrue =>
sanitizedInput.contains(cvTrue)
}
val lowerStripped = withResource(input.strip())(_.lower())
val sanitizedInput = withResource(lowerStripped) { _ =>
withResource(lowerStripped.contains(boolStrings)) { validBools =>
// in ansi mode, fail if any values are not valid bool strings
if (ansiEnabled) {
withResource(validBools.all()) { isAllBool =>
if (isAllBool.isValid && !isAllBool.getBoolean) {
throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE)
}
}
}
// replace non-boolean values with null
withResource(Scalar.fromNull(DType.STRING))(validBools.ifElse(lowerStripped, _))
}
}
withResource(sanitizedInput) { _ =>
// return true, false, or null, as appropriate
withResource(ColumnVector.fromStrings(trueStrings: _*))(sanitizedInput.contains)
}
}
}

Expand Down Expand Up @@ -1072,14 +1053,11 @@ object GpuCast extends Arm {
}
}
}
withResource(sanitized.castTo(dType)) { casted =>
withResource(Scalar.fromNull(dType)) { nulls =>
withResource(isFloat.ifElse(casted, nulls)) { floatsOnly =>
withResource(FloatUtils.getNanScalar(dType)) { nan =>
isNan.ifElse(nan, floatsOnly)
}
}
}
val floatsOnly = withResource(sanitized.castTo(dType)) { casted =>
withResource(Scalar.fromNull(dType))(isFloat.ifElse(casted, _))
}
withResource(floatsOnly) { _ =>
withResource(FloatUtils.getNanScalar(dType))(isNan.ifElse(_, floatsOnly))
}
}
}
Expand Down Expand Up @@ -1131,16 +1109,15 @@ object GpuCast extends Arm {

private def checkResultForAnsiMode(input: ColumnVector, result: ColumnVector,
errMessage: String): ColumnVector = {
closeOnExcept(result) { finalResult =>
withResource(input.isNotNull) { wasNotNull =>
withResource(finalResult.isNull) { isNull =>
withResource(wasNotNull.and(isNull)) { notConverted =>
withResource(notConverted.any()) { notConvertedAny =>
if (notConvertedAny.isValid && notConvertedAny.getBoolean) {
throw new DateTimeException(errMessage)
}
}
}
closeOnExcept(result) { _ =>
val notConverted = lazyReduce(Seq(
lazyResource(input.isNotNull()),
lazyResource(result.isNull())
))(_ and _)
val notConvertedAny = lazyAndThen(notConverted)(_.any())
withResource(notConvertedAny()) { x =>
if (x.isValid && x.getBoolean) {
throw new DateTimeException(errMessage)
}
}
}
Expand Down Expand Up @@ -1244,29 +1221,18 @@ object GpuCast extends Arm {
withResource(orElse) { orElse =>

// valid dates must match the regex and either of the cuDF formats
val isCudfMatch = withResource(input.isTimestamp(cudfFormat1)) { isTimestamp1 =>
withResource(input.isTimestamp(cudfFormat2)) { isTimestamp2 =>
withResource(input.isTimestamp(cudfFormat3)) { isTimestamp3 =>
withResource(input.isTimestamp(cudfFormat4)) { isTimestamp4 =>
withResource(isTimestamp1.or(isTimestamp2)) { isTimestamp12 =>
withResource(isTimestamp12.or(isTimestamp3)) { isTimestamp123 =>
isTimestamp123.or(isTimestamp4)
}
}
}
}
}
}
val isCudfMatch = lazyReduce(
Seq(cudfFormat1, cudfFormat2, cudfFormat3, cudfFormat4)
.map(f => (() => input.isTimestamp(f)))
)(_ or _)

val isValidTimestamp = withResource(isCudfMatch) { isCudfMatch =>
withResource(input.matchesRe(TIMESTAMP_REGEX_FULL)) { isRegexMatch =>
isCudfMatch.and(isRegexMatch)
}
}
val isValidTimestamp = lazyReduce(
Seq(isCudfMatch, () => input.matchesRe(TIMESTAMP_REGEX_FULL))
)(_ and _)

// we only need to parse with one of the cuDF formats because the parsing code ignores
// the ' ' or 'T' between the date and time components
withResource(isValidTimestamp) { _ =>
withResource(isValidTimestamp()) { isValidTimestamp =>
withResource(input.asTimestampMicroseconds(cudfFormat1)) { asDays =>
isValidTimestamp.ifElse(asDays, orElse)
}
Expand Down Expand Up @@ -1576,17 +1542,13 @@ object GpuCast extends Arm {
}
}

withResource(updatedMaxRet) { updatedMax =>
withResource(Scalar.fromLong(minSeconds)) { minSecondsS =>
withResource(longInput.lessThan(minSecondsS)) { lessThanMinSeconds =>
withResource(Scalar.fromLong(Long.MinValue)) { longMinS =>
withResource(lessThanMinSeconds.ifElse(longMinS, updatedMax)) { cv =>
cv.castTo(GpuColumnVector.getNonNestedRapidsType(toType))
}
}
}
val cv = withResource(updatedMaxRet) { updatedMax =>
withResource(Seq(minSeconds, Long.MinValue).safeMap(Scalar.fromLong)) {
case Seq(minSecondsS, longMinS) =>
withResource(longInput.lessThan(minSecondsS))(_.ifElse(longMinS, updatedMax))
}
}
withResource(cv)(_.castTo(GpuColumnVector.getNonNestedRapidsType(toType)))
}
}

Expand Down
Loading