Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patterns such as (3?)+ should now fall back to CPU #4715

Merged
merged 10 commits into from
Feb 18, 2022
78 changes: 58 additions & 20 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,42 @@ class CudfRegexTranspiler(mode: RegexMode) {
cudfRegex.toRegexString
}

/**
 * Shared implementation behind the repetition checks.
 *
 * - A `RegexRepetition` always qualifies.
 * - A `RegexGroup` qualifies when its term satisfies `isRepetition`.
 * - A `RegexSequence` qualifies only when `f` accepts its list of parts AND the
 *   last part satisfies `isRepetition`.
 * - Anything else does not qualify.
 *
 * Note that the recursive checks always go through `isRepetition`, so `f` is
 * only applied to the outermost sequence.
 *
 * @param f predicate over the parts of a sequence, evaluated before the last part is inspected
 * @param e the AST node to inspect
 * @return true if `e` is, or ends in, a repetition according to the rules above
 */
private def isSomeRepetition(f: ListBuffer[RegexAST] => Boolean)(e: RegexAST): Boolean =
  e match {
    case RegexRepetition(_, _) =>
      true
    case RegexGroup(_, inner) =>
      isRepetition(inner)
    case RegexSequence(items) =>
      // `f` is checked first, so `items.last` is only reached when `f` accepted the parts
      f(items) && isRepetition(items.last)
    case _ =>
      false
  }

/**
 * True when the given AST node is a repetition, a group wrapping one, or a
 * non-empty sequence whose last part is one (see `isSomeRepetition`).
 */
private def isRepetition: RegexAST => Boolean =
  ast => isSomeRepetition(parts => parts.nonEmpty)(ast)
/**
 * Variant of `isRepetition` in which a sequence only qualifies when it has
 * exactly one part (see `isSomeRepetition`).
 */
private def isNestedRepetition: RegexAST => Boolean =
  ast => isSomeRepetition(parts => parts.length == 1)(ast)

private def isSupportedRepetitionBase(e: RegexAST): Boolean = {
e match {
case RegexEscaped(ch) if ch != 'd' && ch != 'w' => // example: "\B?"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Comments should be consistent

Suggested change
case RegexEscaped(ch) if ch != 'd' && ch != 'w' => // example: "\B?"
case RegexEscaped(ch) if ch != 'd' && ch != 'w' =>
// example: "\B?"

false

case RegexChar(a) if "$^".contains(a) =>
// example: "$*"
false

case RegexRepetition(_, _) =>
// example: "a*+"
false

case RegexSequence(parts) =>
parts.forall(x => isSupportedRepetitionBase(x))
NVnavkumar marked this conversation as resolved.
Show resolved Hide resolved

case RegexGroup(_, term) =>
isSupportedRepetitionBase(term)

case _ => true
}
}


private def rewrite(regex: RegexAST): RegexAST = {
regex match {

Expand Down Expand Up @@ -628,20 +664,29 @@ class CudfRegexTranspiler(mode: RegexMode) {
throw new RegexUnsupportedException(
"regexp_replace on GPU does not support repetition with ? or *")

case (RegexEscaped(ch), _) if ch != 'd' && ch != 'w' =>
// example: "\B?"
throw new RegexUnsupportedException(nothingToRepeat)
case (_, QuantifierVariableLength(0,None)) if mode == RegexReplaceMode =>
// see https://github.com/NVIDIA/spark-rapids/issues/4468
throw new RegexUnsupportedException(
"regexp_replace on GPU does not support repetition with {0,}")

case (RegexChar(a), _) if "$^".contains(a) =>
// example: "$*"
throw new RegexUnsupportedException(nothingToRepeat)
case (_, QuantifierFixedLength(0)) | (_, QuantifierVariableLength(0,Some(0)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
case (_, QuantifierFixedLength(0)) | (_, QuantifierVariableLength(0,Some(0)))
case (_, QuantifierFixedLength(0)) | (_, QuantifierVariableLength(0, Some(0)))

if mode != RegexFindMode =>
throw new RegexUnsupportedException(
"regex_replace and regex_split on GPU do not support repetition with {0} or {0,0}")

case (RegexRepetition(_, _), _) =>
// example: "a*+"
case (RegexGroup(_, term), SimpleQuantifier(ch))
if "+*".contains(ch) && !isSupportedRepetitionBase(term) =>
throw new RegexUnsupportedException(nothingToRepeat)

case _ =>
case (RegexGroup(_, term), QuantifierVariableLength(_,None))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
case (RegexGroup(_, term), QuantifierVariableLength(_,None))
case (RegexGroup(_, term), QuantifierVariableLength(_, None))

if !isSupportedRepetitionBase(term) =>
// specifically this variable length repetition: \A{2,}
NVnavkumar marked this conversation as resolved.
Show resolved Hide resolved
throw new RegexUnsupportedException(nothingToRepeat)
case (RegexGroup(_, _), SimpleQuantifier(ch)) if ch == '?' =>
RegexRepetition(rewrite(base), quantifier)
case _ if isSupportedRepetitionBase(base) =>
RegexRepetition(rewrite(base), quantifier)
case _ =>
throw new RegexUnsupportedException(nothingToRepeat)

}

Expand All @@ -650,14 +695,6 @@ class CudfRegexTranspiler(mode: RegexMode) {
val rr = rewrite(r)

// cuDF does not support repetition on one side of a choice, such as "a*|a"
def isRepetition(e: RegexAST): Boolean = {
e match {
case RegexRepetition(_, _) => true
case RegexGroup(_, term) => isRepetition(term)
case RegexSequence(parts) if parts.nonEmpty => isRepetition(parts.last)
case _ => false
}
}
if (isRepetition(ll) || isRepetition(rr)) {
throw new RegexUnsupportedException(nothingToRepeat)
}
Expand All @@ -667,8 +704,9 @@ class CudfRegexTranspiler(mode: RegexMode) {
def endsWithLineAnchor(e: RegexAST): Boolean = {
e match {
case RegexSequence(parts) if parts.nonEmpty =>
isBeginOrEndLineAnchor(parts.last)
case _ => false
endsWithLineAnchor(parts.last)
case RegexEscaped('A') => true
case _ => isBeginOrEndLineAnchor(e)
}
}
if (endsWithLineAnchor(ll) || endsWithLineAnchor(rr)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,21 @@ class RegularExpressionParserSuite extends FunSuite {
RegexSequence(ListBuffer(RegexOctalChar("047"), RegexChar('7'))))
}

test("repetition with group containing simple repetition") {
  // "(3?)+" : an optional '3' inside a capture group, with the group itself repeated by '+'
  val optionalThree = RegexRepetition(RegexChar('3'), SimpleQuantifier('?'))
  val captureGroup = RegexGroup(capture = true, RegexSequence(ListBuffer(optionalThree)))
  val expected = RegexSequence(ListBuffer(RegexRepetition(captureGroup, SimpleQuantifier('+'))))
  assert(parse("(3?)+") === expected)
}

test("repetition with group containing escape character") {
  // raw"(\A)+" : the escaped 'A' inside a capture group, with the group repeated by '+'
  val captureGroup = RegexGroup(capture = true, RegexSequence(ListBuffer(RegexEscaped('A'))))
  val expected = RegexSequence(ListBuffer(RegexRepetition(captureGroup, SimpleQuantifier('+'))))
  assert(parse(raw"(\A)+") === expected)
}

test("group containing choice with repetition") {
assert(parse("(\t+|a)") == RegexSequence(ListBuffer(
RegexGroup(capture = true, RegexChoice(RegexSequence(ListBuffer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package com.nvidia.spark.rapids


import java.util.regex.Pattern

import scala.collection.mutable.{HashSet, ListBuffer}
Expand Down Expand Up @@ -42,7 +43,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
"a^|b",
"w$|b",
"\n[^\r\n]x*|^3x",
"]*\\wWW$|zb"
"]*\\wWW$|zb",
"(\\A|\\05)?"
)
// data is not relevant because we are checking for compilation errors
val inputs = Seq("a")
Expand Down Expand Up @@ -119,6 +121,13 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
)
}

test("cuDF does not support single repetition both inside and outside of capture groups") {
  // see https://github.com/NVIDIA/spark-rapids/issues/4487
  // each pattern repeats a group whose content is itself a repetition
  for (pattern <- Seq("(3?)+", "(3?)*", "(3*)+", "((3?))+")) {
    assertUnsupported(pattern, RegexFindMode, "nothing to repeat")
  }
}

test("cuDF does not support OR at BOL / EOL") {
val patterns = Seq("$|a", "^|a")
patterns.foreach(pattern => {
Expand Down Expand Up @@ -171,6 +180,13 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
"\ntest", "test\n", "\ntest\n"))
}

test("string anchor \\A will fall back to CPU in some repetitions") {
  // a group containing only \A, repeated with +, * or {2,}, is expected to be rejected
  for (pattern <- Seq(raw"(\A)+", raw"(\A)*", raw"(\A){2,}")) {
    assertUnsupported(pattern, RegexFindMode, "nothing to repeat")
  }
}

test("string anchor \\Z fall back to CPU") {
for (mode <- Seq(RegexFindMode, RegexReplaceMode)) {
assertUnsupported("\\Z", mode, "string anchor \\Z is not supported")
Expand Down Expand Up @@ -294,6 +310,40 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
assertCpuGpuMatchesRegexpReplace(patterns, inputs)
}

test("regexp_replace - character class repetition - ? and * - fall back to CPU") {
  // in replace mode, `?` and `*` applied to a character class are expected to be rejected
  val expectedMessage = "regexp_replace on GPU does not support repetition with ? or *"
  for (pattern <- Seq(raw"[1a-zA-Z]?", raw"[1a-zA-Z]*")) {
    assertUnsupported(pattern, RegexReplaceMode, expectedMessage)
  }
}

test("regexp_replace - character class repetition - {0,} - fall back to CPU") {
  // in replace mode, the open-ended {0,} quantifier is expected to be rejected
  val expectedMessage = "regexp_replace on GPU does not support repetition with {0,}"
  for (pattern <- Seq(raw"[1a-zA-Z]{0,}")) {
    assertUnsupported(pattern, RegexReplaceMode, expectedMessage)
  }
}

test("regexp_replace - fall back to CPU for {0} or {0,0}") {
  // zero-length repetition on a plain character and on an octal escape, in replace mode
  val expectedMessage =
    "regex_replace and regex_split on GPU do not support repetition with {0} or {0,0}"
  for (pattern <- Seq("a{0}", raw"\02{0}", "a{0,0}", raw"\02{0,0}")) {
    assertUnsupported(pattern, RegexReplaceMode, expectedMessage)
  }
}

test("regexp_split - fall back to CPU for {0} or {0,0}") {
  // zero-length repetition on a plain character and on an octal escape, in split mode
  val expectedMessage =
    "regex_replace and regex_split on GPU do not support repetition with {0} or {0,0}"
  for (pattern <- Seq("a{0}", raw"\02{0}", "a{0,0}", raw"\02{0,0}")) {
    assertUnsupported(pattern, RegexSplitMode, expectedMessage)
  }
}

test("compare CPU and GPU: regexp find fuzz test with limited chars") {
// testing with this limited set of characters finds issues much
// faster than using the full ASCII set
Expand Down