Skip to content

Commit

Permalink
transpile no longer uses option for first return type
Browse files Browse the repository at this point in the history
Signed-off-by: Andy Grove <andygrove@nvidia.com>
  • Loading branch information
andygrove committed May 17, 2022
1 parent 3022e04 commit d6da643
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class GpuRegExpReplaceMeta(
try {
val (pat, repl) =
new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString, replacement)
pattern = pat
pattern = Some(pat)
repl.map(GpuRegExpUtils.backrefConversion).foreach {
case (hasBackref, convertedRep) =>
containsBackref = hasBackref
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,28 +548,21 @@ class CudfRegexTranspiler(mode: RegexMode) {
* Parse Java regular expression and translate into cuDF regular expression.
*
* @param pattern Regular expression that is valid in Java's engine
* @param repl Optional replacement pattern
* @return Tuple of the regular expression in cuDF format and the optional replacement string
*/
def transpile(pattern: String, repl: Option[String]): (Option[String], Option[String]) = {
def transpile(pattern: String, repl: Option[String]): (String, Option[String]) = {
// parse the source regular expression
val regex = new RegexParser(pattern).parse()
// if we have a replacement, parse the replacement string using the regex parser to account
// for backrefs
val replacement = repl match {
case Some(s) => Some(new RegexParser(s).parseReplacement(countCaptureGroups(regex)))
case None => None
}
val replacement = repl.map(s => new RegexParser(s).parseReplacement(countCaptureGroups(regex)))

// validate that the regex is supported by cuDF
val cudfRegex = rewrite(regex, replacement, None)
// write out to regex string, performing minor transformations
// such as adding additional escaping
replacement match {
case Some(replaceAST) =>
(Some(cudfRegex.toRegexString), Some(replaceAST.toRegexString))
case _ =>
(Some(cudfRegex.toRegexString), None)
}
(cudfRegex.toRegexString, replacement.map(_.toRegexString))
}

def transpileToSplittableString(e: RegexAST): Option[String] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ class GpuRLikeMeta(
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
try {
// verify that we support this regex and can transpile it to cuDF format
pattern = (new CudfRegexTranspiler(RegexFindMode).transpile(str.toString, None))._1
pattern = Some(new CudfRegexTranspiler(RegexFindMode).transpile(str.toString, None)._1)
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
Expand Down Expand Up @@ -1023,8 +1023,8 @@ class GpuRegExpExtractMeta(
try {
val javaRegexpPattern = str.toString
// verify that we support this regex and can transpile it to cuDF format
pattern = new CudfRegexTranspiler(RegexFindMode)
.transpile(javaRegexpPattern, None)._1
pattern = Some(new CudfRegexTranspiler(RegexFindMode)
.transpile(javaRegexpPattern, None)._1)
numGroups = countGroups(new RegexParser(javaRegexpPattern).parse())
} catch {
case e: RegexUnsupportedException =>
Expand Down Expand Up @@ -1388,7 +1388,7 @@ abstract class StringSplitRegExpMeta[INPUT <: TernaryExpression](expr: INPUT,
pattern = simplified
case None =>
try {
pattern = transpiler.transpile(utf8Str.toString, None)._1.get
pattern = transpiler.transpile(utf8Str.toString, None)._1
isRegExp = true
} catch {
case e: RegexUnsupportedException =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
val patterns = Seq("a[-b]", "a[+-]", "a[-+]", "a[-]", "a[^-]")
val expected = Seq(raw"a[\-b]", raw"a[+\-]", raw"a[\-+]", raw"a[\-]", "a(?:[\r\n]|[^\\-])")
val transpiler = new CudfRegexTranspiler(RegexFindMode)
val transpiled = patterns.map(transpiler.transpile(_, None)._1.get)
val transpiled = patterns.map(transpiler.transpile(_, None)._1)
assert(transpiled === expected)
}

Expand Down Expand Up @@ -612,7 +612,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
val (isRegex, cudfPattern) = if (RegexParser.isRegExpString(pattern)) {
transpiler.transpileToSplittableString(pattern) match {
case Some(simplified) => (false, simplified)
case _ => (true, transpiler.transpile(pattern, None)._1.get)
case _ => (true, transpiler.transpile(pattern, None)._1)
}
} else {
(false, pattern)
Expand Down Expand Up @@ -669,7 +669,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
for (javaPattern <- javaPatterns) {
val cpu = cpuContains(javaPattern, input)
val cudfPattern =
(new CudfRegexTranspiler(RegexFindMode)).transpile(javaPattern, None)._1.get
new CudfRegexTranspiler(RegexFindMode).transpile(javaPattern, None)._1
val gpu = try {
gpuContains(cudfPattern, input)
} catch {
Expand All @@ -696,17 +696,17 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
(new CudfRegexTranspiler(RegexReplaceMode)).transpile(javaPattern,
Some(REPLACE_STRING))
val gpu = try {
gpuReplace(cudfPattern.get, replaceString.get, input)
gpuReplace(cudfPattern, replaceString.get, input)
} catch {
case e: CudfException =>
fail(s"cuDF failed to compile pattern: ${toReadableString(cudfPattern.get)}, " +
fail(s"cuDF failed to compile pattern: ${toReadableString(cudfPattern)}, " +
s"original: ${toReadableString(javaPattern)}, " +
s"replacement: ${toReadableString(replaceString.get)}", e)
}
for (i <- input.indices) {
if (cpu(i) != gpu(i)) {
fail(s"javaPattern=${toReadableString(javaPattern)}, " +
s"cudfPattern=${toReadableString(cudfPattern.get)}, " +
s"cudfPattern=${toReadableString(cudfPattern)}, " +
s"input='${toReadableString(input(i))}', " +
s"cpu=${toReadableString(cpu(i))}, " +
s"gpu=${toReadableString(gpu(i))}")
Expand Down Expand Up @@ -806,9 +806,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
private def transpile(pattern: String, mode: RegexMode): String = {
mode match {
case RegexReplaceMode =>
new CudfRegexTranspiler(mode).transpile(pattern, Some(REPLACE_STRING))._1.get
new CudfRegexTranspiler(mode).transpile(pattern, Some(REPLACE_STRING))._1
case _ =>
new CudfRegexTranspiler(mode).transpile(pattern, None)._1.get
new CudfRegexTranspiler(mode).transpile(pattern, None)._1

}
}
Expand Down

0 comments on commit d6da643

Please sign in to comment.