From ade65915c988546dc20c9c937a5f5e3f2c7456e7 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 16 Nov 2021 17:46:48 -0700 Subject: [PATCH] Regexp_replace support regexp [databricks] (#4063) * Integrate regex parser and transpiler with regexp_replace Signed-off-by: Andy Grove * revert change * fix shim builds * fix regression in qa_nightly_test * update compatibility docs * fix compilation error in 31xdb shim * attempt to fix 312db compilation error * Restore legacy behavior of using GpuStringReplace if the regex pattern is a literal string * fix typo * fix import in 301db shim * Update docs/compatibility.md Co-authored-by: Jason Lowe * Update docs/compatibility.md Co-authored-by: Jason Lowe * Improve willNotWorkOnGpu messages * Improve willNotWorkOnGpu messages * Improve willNotWorkOnGpu messages * update generated docs * fix regression * remove unused imports * remove unused imports * remove unused imports * remove unused imports * remove unused imports Co-authored-by: Jason Lowe --- docs/compatibility.md | 41 +++-- docs/configs.md | 4 +- docs/supported_ops.md | 6 +- .../src/main/python/qa_nightly_select_test.py | 3 +- .../src/main/python/string_test.py | 17 +- .../shims/v2/GpuRegExpReplaceExec.scala | 68 ++++++++ .../rapids/shims/v2/SparkBaseShims.scala | 19 +-- .../shims/v2/GpuRegExpReplaceMeta.scala | 68 ++++++++ .../rapids/shims/v2/SparkBaseShims.scala | 19 +-- .../shims/v2/GpuRegExpReplaceExec.scala | 79 +++++++++ .../rapids/shims/v2/SparkBaseShims.scala | 27 +-- .../shims/v2/GpuRegExpReplaceExec.scala | 79 +++++++++ .../rapids/shims/v2/SparkBaseShims.scala | 27 +-- .../spark/rapids/shims/v2/Spark32XShims.scala | 27 +-- .../nvidia/spark/rapids/GpuOverrides.scala | 4 +- .../com/nvidia/spark/rapids/RapidsMeta.scala | 21 ++- .../com/nvidia/spark/rapids/RegexParser.scala | 35 +++- .../spark/sql/rapids/stringFunctions.scala | 99 ++++++++++- .../RegularExpressionTranspilerSuite.scala | 158 ++++++++++++++---- .../spark/rapids/StringFallbackSuite.scala | 35 ++-- 20 files changed, 665 insertions(+), 171 deletions(-) create mode 100644 sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala create mode 100644 sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala create mode 100644 sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala create mode 100644 sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala diff --git a/docs/compatibility.md b/docs/compatibility.md index 6d04732a380..9766d0ad374 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -265,22 +265,43 @@ the end of the string. This will be fixed in a future release. The issue is ## Regular Expressions -### regexp_replace +The following Apache Spark regular expression functions and expressions are supported on the GPU: -The RAPIDS Accelerator for Apache Spark currently supports string literal matches, not wildcard -matches for the `regexp_replace` function and will fall back to CPU if a regular expression pattern -is provided. +- `RLIKE` +- `regexp` +- `regexp_like` +- `regexp_replace` -### RLike +These operations are disabled by default because of known incompatibilities between the Java regular expression +engine that Spark uses and the cuDF regular expression engine on the GPU, and also because the regular expression +kernels can potentially have high memory overhead. 
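+
+As a rough illustration (the session, table, and column names here are assumptions for the
+example, not part of the plugin), the settings listed below can be applied to an existing
+`SparkSession` before running a pattern-based replacement:
+
+```scala
+// Assumes `spark` is a SparkSession with the RAPIDS Accelerator plugin loaded; the settings
+// can also be passed with --conf at startup. Both are disabled by default.
+spark.conf.set("spark.rapids.sql.expression.RLike", "true")
+spark.conf.set("spark.rapids.sql.expression.RegExpReplace", "true")
+
+// With RegExpReplace enabled, a true regular expression pattern (not just a literal
+// string) can be evaluated on the GPU, subject to the fallback cases described below.
+spark.sql("SELECT regexp_replace(description, '[0-9]+', 'N') FROM items").show()
+```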
-The GPU implementation of `RLike` has the following known issues where behavior is not consistent with Apache Spark and -this expression is disabled by default. It can be enabled setting `spark.rapids.sql.expression.RLike=true`. +These operations can be enabled on the GPU with the following configuration settings: -- `$` does not match the end of string if the string ends with a line-terminator +- `spark.rapids.sql.expression.RLike=true` (for `RLIKE`, `regexp`, and `regexp_like`) +- `spark.rapids.sql.expression.RegExpReplace=true` for `regexp_replace` + +Even when these expressions are enabled, there are instances where regular expression operations will fall back to +CPU when the RAPIDS Accelerator determines that a pattern is either unsupported or would produce incorrect results on the GPU. + +Here are some examples of regular expression patterns that are not supported on the GPU and will fall back to the CPU. + +- Lazy quantifiers, such as `a*?` +- Possessive quantifiers, such as `a*+` +- Character classes that use union, intersection, or subtraction semantics, such as `[a-d[m-p]]`, `[a-z&&[def]]`, + or `[a-z&&[^bc]]` +- Word and non-word boundaries, `\b` and `\B` +- Empty groups: `()` +- Regular expressions containing null characters (unless the pattern is a simple literal string) +- Beginning-of-line and end-of-line anchors (`^` and `$`) are not supported in some contexts, such as when combined +- with a choice (`^|a`) or when used anywhere in `regexp_replace` patterns. + +In addition to these cases that can be detected, there is also one known issue that can cause incorrect results: + +- `$` does not match the end of a string if the string ends with a line-terminator ([cuDF issue #9620](https://github.com/rapidsai/cudf/issues/9620)) -`RLike` will fall back to CPU if any regular expressions are detected that are not supported on the GPU -or would produce different results on the GPU. +Work is ongoing to increase the range of regular expressions that can run on the GPU. ## Timestamps diff --git a/docs/configs.md b/docs/configs.md index 4c6781de6d7..d66fe59ffdd 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -259,10 +259,10 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.PromotePrecision| |PromotePrecision before arithmetic operations between DecimalType data|true|None| spark.rapids.sql.expression.PythonUDF| |UDF run in an external python process. Does not actually run on the GPU, but the transfer of data to/from it can be accelerated|true|None| spark.rapids.sql.expression.Quarter|`quarter`|Returns the quarter of the year for date, in the range 1 to 4|true|None| -spark.rapids.sql.expression.RLike|`rlike`|RLike|false|This is disabled by default because The GPU implementation of rlike is not compatible with Apache Spark. See the compatibility guide for more information.| +spark.rapids.sql.expression.RLike|`rlike`|RLike|false|This is disabled by default because the implementation is not 100% compatible. See the compatibility guide for more information.| spark.rapids.sql.expression.Rand|`random`, `rand`|Generate a random column with i.i.d. 
uniformly distributed values in [0, 1)|true|None| spark.rapids.sql.expression.Rank|`rank`|Window function that returns the rank value within the aggregation window|true|None| -spark.rapids.sql.expression.RegExpReplace|`regexp_replace`|RegExpReplace support for string literal input patterns|true|None| +spark.rapids.sql.expression.RegExpReplace|`regexp_replace`|RegExpReplace support for string literal input patterns|false|This is disabled by default because the implementation is not 100% compatible. See the compatibility guide for more information.| spark.rapids.sql.expression.Remainder|`%`, `mod`|Remainder or modulo|true|None| spark.rapids.sql.expression.Rint|`rint`|Rounds up a double value to the nearest double equal to an integer|true|None| spark.rapids.sql.expression.Round|`round`|Round an expression to d decimal places using HALF_UP rounding mode|true|None| diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 1593a151152..4173fe9b3bd 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -9638,7 +9638,7 @@ are limited. RLike `rlike` RLike -This is disabled by default because The GPU implementation of rlike is not compatible with Apache Spark. See the compatibility guide for more information. +This is disabled by default because the implementation is not 100% compatible. See the compatibility guide for more information. project str @@ -9826,7 +9826,7 @@ are limited. RegExpReplace `regexp_replace` RegExpReplace support for string literal input patterns -None +This is disabled by default because the implementation is not 100% compatible. See the compatibility guide for more information. project str @@ -9859,7 +9859,7 @@ are limited. -PS
very limited regex support;
Literal value only
+PS
Literal value only
diff --git a/integration_tests/src/main/python/qa_nightly_select_test.py b/integration_tests/src/main/python/qa_nightly_select_test.py index 4b635e9ccac..fd362de14f6 100644 --- a/integration_tests/src/main/python/qa_nightly_select_test.py +++ b/integration_tests/src/main/python/qa_nightly_select_test.py @@ -150,7 +150,8 @@ def idfn(val): 'spark.rapids.sql.hasNans': 'false', 'spark.rapids.sql.castStringToFloat.enabled': 'true', 'spark.rapids.sql.castFloatToIntegralTypes.enabled': 'true', - 'spark.rapids.sql.castFloatToString.enabled': 'true' + 'spark.rapids.sql.castFloatToString.enabled': 'true', + 'spark.rapids.sql.expression.RegExpReplace': 'true' } _first_last_qa_conf = copy_and_update(_qa_conf, { diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index cfd0010d2ef..5d300c2eed9 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -337,7 +337,8 @@ def test_re_replace(): 'REGEXP_REPLACE(a, "TEST", "PROD")', 'REGEXP_REPLACE(a, "TEST", "")', 'REGEXP_REPLACE(a, "TEST", "%^[]\ud720")', - 'REGEXP_REPLACE(a, "TEST", NULL)')) + 'REGEXP_REPLACE(a, "TEST", NULL)'), + conf={'spark.rapids.sql.expression.RegExpReplace': 'true'}) def test_re_replace_null(): gen = mk_str_gen('[\u0000 ]{0,2}TE[\u0000 ]{0,2}ST[\u0000 ]{0,2}')\ @@ -356,7 +357,8 @@ def test_re_replace_null(): 'REGEXP_REPLACE(a, "\x00", "NULL")', 'REGEXP_REPLACE(a, "\0", "NULL")', 'REGEXP_REPLACE(a, "TE\u0000ST", "PROD")', - 'REGEXP_REPLACE(a, "TE\u0000\u0000ST", "PROD")')) + 'REGEXP_REPLACE(a, "TE\u0000\u0000ST", "PROD")'), + conf={'spark.rapids.sql.expression.RegExpReplace': 'true'}) def test_length(): gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') @@ -470,6 +472,17 @@ def test_like_complex_escape(): 'a like "_oo"'), conf={'spark.sql.parser.escapedStringLiterals': 'true'}) +def test_regexp_replace(): + gen = mk_str_gen('[abcd]{0,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_replace(a, "a", "A")', + 'regexp_replace(a, "[^xyz]", "A")', + 'regexp_replace(a, "([^x])|([^y])", "A")', + 'regexp_replace(a, "(?:aa)+", "A")', + 'regexp_replace(a, "a|b|c", "A")'), + conf={'spark.rapids.sql.expression.RegExpReplace': 'true'}) + def test_rlike(): gen = mk_str_gen('[abcd]{1,3}') assert_gpu_and_cpu_are_equal_collect( diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala new file mode 100644 index 00000000000..e5faece3781 --- /dev/null +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.nvidia.spark.rapids.shims.v2 + +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexUnsupportedException, TernaryExprMeta} + +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace} +import org.apache.spark.sql.types.DataTypes +import org.apache.spark.unsafe.types.UTF8String + +class GpuRegExpReplaceMeta( + expr: RegExpReplace, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends TernaryExprMeta[RegExpReplace](expr, conf, parent, rule) { + + override def tagExprForGpu(): Unit = { + expr.regexp match { + case Literal(null, _) => + willNotWorkOnGpu(s"null pattern is not supported on GPU") + case Literal(s: UTF8String, DataTypes.StringType) => + val pattern = s.toString + if (pattern.isEmpty) { + willNotWorkOnGpu(s"empty pattern is not supported on GPU") + } + + if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { + // use GpuStringReplace + } else { + try { + new CudfRegexTranspiler(replace = true).transpile(pattern) + } catch { + case e: RegexUnsupportedException => + willNotWorkOnGpu(e.getMessage) + } + } + + case _ => + willNotWorkOnGpu(s"non-literal pattern is not supported on GPU") + } + } + + override def convertToGpu( + lhs: Expression, + regexp: Expression, + rep: Expression): GpuExpression = { + if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { + GpuStringReplace(lhs, regexp, rep) + } else { + GpuRegExpReplace(lhs, regexp, rep) + } + } +} diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala index d57ff96369f..06bbf39af1e 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala @@ -56,7 +56,7 @@ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.window.WindowExecBase import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.rapids.{GpuAbs, GpuAverage, GpuFileSourceScanExec, GpuStringReplace, GpuTimeSub} +import org.apache.spark.sql.rapids.{GpuAbs, GpuAverage, GpuFileSourceScanExec, GpuTimeSub} import org.apache.spark.sql.rapids.execution.{GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase, JoinTypeChecks, SerializeBatchDeserializeHostBuffer, SerializeConcatHostBuffersDeserializeBatch, TrampolineUtil} import org.apache.spark.sql.rapids.execution.python._ import org.apache.spark.sql.rapids.execution.python.shims.v2._ @@ -330,20 +330,11 @@ abstract class SparkBaseShims extends Spark30XShims { "RegExpReplace support for string literal input patterns", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("regex", TypeSig.lit(TypeEnum.STRING) - .withPsNote(TypeEnum.STRING, "very limited regex support"), TypeSig.STRING), + ParamCheck("regex", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), ParamCheck("rep", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), - (a, conf, p, r) => new TernaryExprMeta[RegExpReplace](a, conf, p, r) { - override def tagExprForGpu(): Unit = { - if (!GpuOverrides.isSupportedStringReplacePattern(a.regexp)) { - willNotWorkOnGpu( - "Only non-null, non-empty 
String literals that are not regex patterns " + - "are supported by RegExpReplace on the GPU") - } - } - override def convertToGpu(lhs: Expression, regexp: Expression, - rep: Expression): GpuExpression = GpuStringReplace(lhs, regexp, rep) - }), + (a, conf, p, r) => new GpuRegExpReplaceMeta(a, conf, p, r)).disabledByDefault( + "the implementation is not 100% compatible. " + + "See the compatibility guide for more information."), GpuScalaUDFMeta.exprMeta ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap } diff --git a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala new file mode 100644 index 00000000000..fdf86447d1b --- /dev/null +++ b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.shims.v2 + +import com.nvidia.spark.rapids._ + +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace} +import org.apache.spark.sql.types.DataTypes +import org.apache.spark.unsafe.types.UTF8String + +class GpuRegExpReplaceMeta( + expr: RegExpReplace, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends TernaryExprMeta[RegExpReplace](expr, conf, parent, rule) { + + override def tagExprForGpu(): Unit = { + expr.regexp match { + case Literal(null, _) => + willNotWorkOnGpu(s"null pattern is not supported on GPU") + case Literal(s: UTF8String, DataTypes.StringType) => + val pattern = s.toString + if (pattern.isEmpty) { + willNotWorkOnGpu(s"empty pattern is not supported on GPU") + } + + if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { + // use GpuStringReplace + } else { + try { + new CudfRegexTranspiler(replace = true).transpile(pattern) + } catch { + case e: RegexUnsupportedException => + willNotWorkOnGpu(e.getMessage) + } + } + + case _ => + willNotWorkOnGpu(s"non-literal pattern is not supported on GPU") + } + } + + override def convertToGpu( + lhs: Expression, + regexp: Expression, + rep: Expression): GpuExpression = { + if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { + GpuStringReplace(lhs, regexp, rep) + } else { + GpuRegExpReplace(lhs, regexp, rep) + } + } +} diff --git a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala index 42f4feddd10..563459f4dbe 100644 --- a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala +++ b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala @@ -52,7 +52,7 @@ import 
org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNes import org.apache.spark.sql.execution.python.{AggregateInPandasExec, ArrowEvalPythonExec, FlatMapGroupsInPandasExec, MapInPandasExec, WindowInPandasExec} import org.apache.spark.sql.execution.window.WindowExecBase import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.rapids.{GpuAbs, GpuAverage, GpuFileSourceScanExec, GpuStringReplace, GpuTimeSub} +import org.apache.spark.sql.rapids.{GpuAbs, GpuAverage, GpuFileSourceScanExec, GpuTimeSub} import org.apache.spark.sql.rapids.execution.{GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase, JoinTypeChecks, SerializeBatchDeserializeHostBuffer, SerializeConcatHostBuffersDeserializeBatch} import org.apache.spark.sql.rapids.execution.python._ import org.apache.spark.sql.rapids.execution.python.shims.v2._ @@ -293,20 +293,11 @@ abstract class SparkBaseShims extends Spark30XShims { "RegExpReplace support for string literal input patterns", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("regex", TypeSig.lit(TypeEnum.STRING) - .withPsNote(TypeEnum.STRING, "very limited regex support"), TypeSig.STRING), + ParamCheck("regex", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), ParamCheck("rep", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), - (a, conf, p, r) => new TernaryExprMeta[RegExpReplace](a, conf, p, r) { - override def tagExprForGpu(): Unit = { - if (!GpuOverrides.isSupportedStringReplacePattern(a.regexp)) { - willNotWorkOnGpu( - "Only non-null, non-empty String literals that are not regex patterns " + - "are supported by RegExpReplace on the GPU") - } - } - override def convertToGpu(lhs: Expression, regexp: Expression, - rep: Expression): GpuExpression = GpuStringReplace(lhs, regexp, rep) - }), + (a, conf, p, r) => new GpuRegExpReplaceMeta(a, conf, p, r)).disabledByDefault( + "the implementation is not 100% compatible. " + + "See the compatibility guide for more information."), GpuScalaUDFMeta.exprMeta ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap } diff --git a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala new file mode 100644 index 00000000000..d992c6b128d --- /dev/null +++ b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.nvidia.spark.rapids.shims.v2 + +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException} + +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace} +import org.apache.spark.sql.types.DataTypes +import org.apache.spark.unsafe.types.UTF8String + +class GpuRegExpReplaceMeta( + expr: RegExpReplace, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends QuaternaryExprMeta[RegExpReplace](expr, conf, parent, rule) { + + override def tagExprForGpu(): Unit = { + expr.regexp match { + case Literal(null, _) => + willNotWorkOnGpu(s"null pattern is not supported on GPU") + case Literal(s: UTF8String, DataTypes.StringType) => + val pattern = s.toString + if (pattern.isEmpty) { + willNotWorkOnGpu(s"empty pattern is not supported on GPU") + } + + if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { + // use GpuStringReplace + } else { + try { + new CudfRegexTranspiler(replace = true).transpile(pattern) + } catch { + case e: RegexUnsupportedException => + willNotWorkOnGpu(e.getMessage) + } + } + + case _ => + willNotWorkOnGpu(s"non-literal pattern is not supported on GPU") + } + + GpuOverrides.extractLit(expr.pos).foreach { lit => + if (lit.value.asInstanceOf[Int] != 1) { + willNotWorkOnGpu("only a search starting position of 1 is supported") + } + } + } + + override def convertToGpu( + lhs: Expression, + regexp: Expression, + rep: Expression, + pos: Expression): GpuExpression = { + // ignore the pos expression which must be a literal 1 after tagging check + require(childExprs.length == 4, + s"Unexpected child count for RegExpReplace: ${childExprs.length}") + val Seq(subject, regexp, rep) = childExprs.take(3).map(_.convertToGpu()) + if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { + GpuStringReplace(subject, regexp, rep) + } else { + GpuRegExpReplace(subject, regexp, rep) + } + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala b/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala index 94235545e3a..9c2c235e57f 100644 --- a/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala +++ b/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala @@ -226,33 +226,14 @@ abstract class SparkBaseShims extends Spark31XShims { "RegExpReplace support for string literal input patterns", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("regex", TypeSig.lit(TypeEnum.STRING) - .withPsNote(TypeEnum.STRING, "very limited regex support"), TypeSig.STRING), + ParamCheck("regex", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), ParamCheck("rep", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), ParamCheck("pos", TypeSig.lit(TypeEnum.INT) .withPsNote(TypeEnum.INT, "only a value of 1 is supported"), TypeSig.lit(TypeEnum.INT)))), - (a, conf, p, r) => new ExprMeta[RegExpReplace](a, conf, p, r) { - override def tagExprForGpu(): Unit = { - if (!GpuOverrides.isSupportedStringReplacePattern(a.regexp)) { - willNotWorkOnGpu( - "Only non-null, non-empty String literals that are not regex patterns " + - "are supported by RegExpReplace on the 
GPU") - } - GpuOverrides.extractLit(a.pos).foreach { lit => - if (lit.value.asInstanceOf[Int] != 1) { - willNotWorkOnGpu("Only a search starting position of 1 is supported") - } - } - } - override def convertToGpu(): GpuExpression = { - // ignore the pos expression which must be a literal 1 after tagging check - require(childExprs.length == 4, - s"Unexpected child count for RegExpReplace: ${childExprs.length}") - val Seq(subject, regexp, rep) = childExprs.take(3).map(_.convertToGpu()) - GpuStringReplace(subject, regexp, rep) - } - }), + (a, conf, p, r) => new GpuRegExpReplaceMeta(a, conf, p, r)).disabledByDefault( + "the implementation is not 100% compatible. " + + "See the compatibility guide for more information."), // Spark 3.1.1-specific LEAD expression, using custom OffsetWindowFunctionMeta. GpuOverrides.expr[Lead]( "Window function that returns N entries ahead of this one", diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala new file mode 100644 index 00000000000..d992c6b128d --- /dev/null +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.nvidia.spark.rapids.shims.v2 + +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException} + +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace} +import org.apache.spark.sql.types.DataTypes +import org.apache.spark.unsafe.types.UTF8String + +class GpuRegExpReplaceMeta( + expr: RegExpReplace, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends QuaternaryExprMeta[RegExpReplace](expr, conf, parent, rule) { + + override def tagExprForGpu(): Unit = { + expr.regexp match { + case Literal(null, _) => + willNotWorkOnGpu(s"null pattern is not supported on GPU") + case Literal(s: UTF8String, DataTypes.StringType) => + val pattern = s.toString + if (pattern.isEmpty) { + willNotWorkOnGpu(s"empty pattern is not supported on GPU") + } + + if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { + // use GpuStringReplace + } else { + try { + new CudfRegexTranspiler(replace = true).transpile(pattern) + } catch { + case e: RegexUnsupportedException => + willNotWorkOnGpu(e.getMessage) + } + } + + case _ => + willNotWorkOnGpu(s"non-literal pattern is not supported on GPU") + } + + GpuOverrides.extractLit(expr.pos).foreach { lit => + if (lit.value.asInstanceOf[Int] != 1) { + willNotWorkOnGpu("only a search starting position of 1 is supported") + } + } + } + + override def convertToGpu( + lhs: Expression, + regexp: Expression, + rep: Expression, + pos: Expression): GpuExpression = { + // ignore the pos expression which must be a literal 1 after tagging check + require(childExprs.length == 4, + s"Unexpected child count for RegExpReplace: ${childExprs.length}") + val Seq(subject, regexp, rep) = childExprs.take(3).map(_.convertToGpu()) + if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { + GpuStringReplace(subject, regexp, rep) + } else { + GpuRegExpReplace(subject, regexp, rep) + } + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala index 9899d2e0fbd..02ab421d7f4 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/SparkBaseShims.scala @@ -228,33 +228,14 @@ abstract class SparkBaseShims extends Spark30XShims { "RegExpReplace support for string literal input patterns", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("regex", TypeSig.lit(TypeEnum.STRING) - .withPsNote(TypeEnum.STRING, "very limited regex support"), TypeSig.STRING), + ParamCheck("regex", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), ParamCheck("rep", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), ParamCheck("pos", TypeSig.lit(TypeEnum.INT) .withPsNote(TypeEnum.INT, "only a value of 1 is supported"), TypeSig.lit(TypeEnum.INT)))), - (a, conf, p, r) => new ExprMeta[RegExpReplace](a, conf, p, r) { - override def tagExprForGpu(): Unit = { - if (!GpuOverrides.isSupportedStringReplacePattern(a.regexp)) { - willNotWorkOnGpu( - "Only non-null, non-empty String literals that are not regex patterns " + - "are supported by RegExpReplace on the GPU") - } - GpuOverrides.extractLit(a.pos).foreach 
{ lit => - if (lit.value.asInstanceOf[Int] != 1) { - willNotWorkOnGpu("Only a search starting position of 1 is supported") - } - } - } - override def convertToGpu(): GpuExpression = { - // ignore the pos expression which must be a literal 1 after tagging check - require(childExprs.length == 4, - s"Unexpected child count for RegExpReplace: ${childExprs.length}") - val Seq(subject, regexp, rep) = childExprs.take(3).map(_.convertToGpu()) - GpuStringReplace(subject, regexp, rep) - } - }), + (a, conf, p, r) => new GpuRegExpReplaceMeta(a, conf, p, r)).disabledByDefault( + "the implementation is not 100% compatible. " + + "See the compatibility guide for more information."), // Spark 3.1.1-specific LEAD expression, using custom OffsetWindowFunctionMeta. GpuOverrides.expr[Lead]( "Window function that returns N entries ahead of this one", diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala index 98266dfd396..b408b3adc06 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala @@ -61,7 +61,7 @@ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.window.WindowExecBase import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} -import org.apache.spark.sql.rapids.{GpuAbs, GpuAnsi, GpuAverage, GpuElementAt, GpuFileSourceScanExec, GpuGetArrayItem, GpuGetArrayItemMeta, GpuGetMapValue, GpuGetMapValueMeta, GpuStringReplace} +import org.apache.spark.sql.rapids.{GpuAbs, GpuAnsi, GpuAverage, GpuElementAt, GpuFileSourceScanExec, GpuGetArrayItem, GpuGetArrayItemMeta, GpuGetMapValue, GpuGetMapValueMeta} import org.apache.spark.sql.rapids.execution._ import org.apache.spark.sql.rapids.execution.python._ import org.apache.spark.sql.rapids.execution.python.shims.v2.GpuFlatMapGroupsInPandasExecMeta @@ -368,33 +368,12 @@ trait Spark32XShims extends SparkShims { "RegExpReplace support for string literal input patterns", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("regex", TypeSig.lit(TypeEnum.STRING) - .withPsNote(TypeEnum.STRING, "very limited regex support"), TypeSig.STRING), + ParamCheck("regex", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), ParamCheck("rep", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), ParamCheck("pos", TypeSig.lit(TypeEnum.INT) .withPsNote(TypeEnum.INT, "only a value of 1 is supported"), TypeSig.lit(TypeEnum.INT)))), - (a, conf, p, r) => new ExprMeta[RegExpReplace](a, conf, p, r) { - override def tagExprForGpu(): Unit = { - if (!GpuOverrides.isSupportedStringReplacePattern(a.regexp)) { - willNotWorkOnGpu( - "Only non-null, non-empty String literals that are not regex patterns " + - "are supported by RegExpReplace on the GPU") - } - GpuOverrides.extractLit(a.pos).foreach { lit => - if (lit.value.asInstanceOf[Int] != 1) { - willNotWorkOnGpu("Only a search starting position of 1 is supported") - } - } - } - override def convertToGpu(): GpuExpression = { - // ignore the pos expression which must be a literal 1 after tagging check - require(childExprs.length == 4, - s"Unexpected child count for RegExpReplace: ${childExprs.length}") - val Seq(subject, regexp, rep) = childExprs.take(3).map(_.convertToGpu()) - GpuStringReplace(subject, regexp, rep) - } - }), + (a, conf, p, r) => new 
GpuRegExpReplaceMeta(a, conf, p, r)), // Spark 3.2.0-specific LEAD expression, using custom OffsetWindowFunctionMeta. GpuOverrides.expr[Lead]( "Window function that returns N entries ahead of this one", diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 2081a6f436a..45a69c9e8ae 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2950,8 +2950,8 @@ object GpuOverrides extends Logging { ("str", TypeSig.STRING, TypeSig.STRING), ("regexp", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING)), (a, conf, p, r) => new GpuRLikeMeta(a, conf, p, r)).disabledByDefault( - "The GPU implementation of rlike is not " + - "compatible with Apache Spark. See the compatibility guide for more information."), + "the implementation is not 100% compatible. " + + "See the compatibility guide for more information."), expr[Length]( "String character length or binary byte length", ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala index 751b1013664..56956b83a59 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -20,7 +20,7 @@ import java.time.ZoneId import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, ComplexTypeMergingExpression, Expression, String2TrimExpression, TernaryExpression, UnaryExpression, WindowExpression, WindowFunction} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, ComplexTypeMergingExpression, Expression, QuaternaryExpression, String2TrimExpression, TernaryExpression, UnaryExpression, WindowExpression, WindowFunction} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -1246,6 +1246,25 @@ abstract class TernaryExprMeta[INPUT <: TernaryExpression]( val2: Expression): GpuExpression } +/** + * Base class for metadata around `QuaternaryExpression`. 
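+ * Subclasses implement the four-argument `convertToGpu` overload; the no-argument
+ * `convertToGpu` defined here converts the four child expressions and delegates to it.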
+ */ +abstract class QuaternaryExprMeta[INPUT <: QuaternaryExpression]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends ExprMeta[INPUT](expr, conf, parent, rule) { + + override final def convertToGpu(): GpuExpression = { + val Seq(child0, child1, child2, child3) = childExprs.map(_.convertToGpu()) + convertToGpu(child0, child1, child2, child3) + } + + def convertToGpu(val0: Expression, val1: Expression, + val2: Expression, val3: Expression): GpuExpression +} + abstract class String2TrimExpressionMeta[INPUT <: String2TrimExpression]( expr: INPUT, conf: RapidsConf, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 6781dbfffa5..c60e8db3524 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -407,11 +407,22 @@ class RegexParser(pattern: String) { /** * Transpile Java/Spark regular expression to a format that cuDF supports, or throw an exception * if this is not possible. + * + * @param replace True if performing a replacement (regexp_replace), false + * if matching only (rlike) */ -class CudfRegexTranspiler { +class CudfRegexTranspiler(replace: Boolean) { - val nothingToRepeat = "nothing to repeat" + // cuDF throws a "nothing to repeat" exception for many of the edge cases that are + // rejected by the transpiler + private val nothingToRepeat = "nothing to repeat" + /** + * Parse Java regular expression and translate into cuDF regular expression. + * + * @param pattern Regular expression that is valid in Java's engine + * @return Regular expression in cuDF format + */ def transpile(pattern: String): String = { // parse the source regular expression val regex = new RegexParser(pattern).parse() @@ -429,6 +440,11 @@ class CudfRegexTranspiler { case '.' => // workaround for https://github.com/rapidsai/cudf/issues/9619 RegexCharacterClass(negated = true, ListBuffer(RegexChar('\r'), RegexChar('\n'))) + case '^' | '$' if replace => + // this is a bit extreme and it would be good to replace with finer-grained + // rules + throw new RegexUnsupportedException("regexp_replace on GPU does not support ^ or $") + case _ => regex } @@ -463,10 +479,9 @@ class CudfRegexTranspiler { // - "[a-b[c-d]]" is supported by Java but not cuDF throw new RegexUnsupportedException("nested character classes are not supported") case _ => - } val components: Seq[RegexCharacterClassComponent] = characters - .map(ch => rewrite(ch).asInstanceOf[RegexCharacterClassComponent]) + .map(x => rewrite(x).asInstanceOf[RegexCharacterClassComponent]) RegexCharacterClass(negated, ListBuffer(components: _*)) case RegexSequence(parts) => @@ -486,9 +501,20 @@ class CudfRegexTranspiler { // falling back to CPU throw new RegexUnsupportedException(nothingToRepeat) } + if (replace && parts.length == 1 && (isRegexChar(parts.head, '^') + || isRegexChar(parts.head, '$'))) { + throw new RegexUnsupportedException("regexp_replace on GPU does not support ^ or $") + } RegexSequence(parts.map(rewrite)) case RegexRepetition(base, quantifier) => (base, quantifier) match { + case (_, SimpleQuantifier(ch)) if replace && "?*".contains(ch) => + // example: pattern " ?", input "] b[", replace with "X": + // java: X]XXbX[X + // cuDF: XXXX] b[ + throw new RegexUnsupportedException( + "regexp_replace on GPU does not support repetition with ? 
or *") + case (RegexEscaped(_), _) => // example: "\B?" throw new RegexUnsupportedException(nothingToRepeat) @@ -503,6 +529,7 @@ class CudfRegexTranspiler { case _ => RegexRepetition(rewrite(base), quantifier) + } case RegexChoice(l, r) => diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index bc413b48e58..c2b584507c6 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -755,7 +755,7 @@ class GpuRLikeMeta( case Literal(str: UTF8String, _) => try { // verify that we support this regex and can transpile it to cuDF format - new CudfRegexTranspiler().transpile(str.toString) + new CudfRegexTranspiler(replace = false).transpile(str.toString) } catch { case e: RegexUnsupportedException => willNotWorkOnGpu(e.getMessage) @@ -790,7 +790,7 @@ case class GpuRLike(left: Expression, right: Expression) "Cannot have an invalid scalar value as right side operand in RLike") } try { - val cudfRegex = new CudfRegexTranspiler().transpile(pattern) + val cudfRegex = new CudfRegexTranspiler(replace = false).transpile(pattern) lhs.getBase.containsRe(cudfRegex) } catch { case _: RegexUnsupportedException => @@ -810,6 +810,101 @@ case class GpuRLike(left: Expression, right: Expression) override def dataType: DataType = BooleanType } +object GpuRegExpReplaceMeta { + def isSupportedRegExpReplacePattern(pattern: String): Boolean = { + try { + new CudfRegexTranspiler(replace = true).transpile(pattern) + true + } catch { + case _: RegexUnsupportedException => false + } + } +} + +case class GpuRegExpReplace( + srcExpr: Expression, + searchExpr: Expression, + replaceExpr: Expression) + extends GpuTernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = srcExpr.dataType + + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) + + override def first: Expression = srcExpr + override def second: Expression = searchExpr + override def third: Expression = replaceExpr + + def this(srcExpr: Expression, searchExpr: Expression) = { + this(srcExpr, searchExpr, GpuLiteral("", StringType)) + } + + override def doColumnar( + strExpr: GpuColumnVector, + searchExpr: GpuColumnVector, + replaceExpr: GpuColumnVector): ColumnVector = + throw new UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") + + override def doColumnar( + strExpr: GpuScalar, + searchExpr: GpuColumnVector, + replaceExpr: GpuColumnVector): ColumnVector = + throw new UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") + + override def doColumnar( + strExpr: GpuScalar, + searchExpr: GpuScalar, + replaceExpr: GpuColumnVector): ColumnVector = + throw new UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") + + override def doColumnar( + strExpr: GpuScalar, + searchExpr: GpuColumnVector, + replaceExpr: GpuScalar): ColumnVector = + throw new UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") + + override def doColumnar( + strExpr: GpuColumnVector, + searchExpr: GpuScalar, + replaceExpr: GpuColumnVector): ColumnVector = + throw new UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") + + override def doColumnar( + strExpr: GpuColumnVector, + searchExpr: GpuScalar, + replaceExpr: GpuScalar): ColumnVector = { + + searchExpr.getValue match { + case null => + // Return 
original string if search string is null + strExpr.getBase.asStrings() + case pattern: UTF8String => + try { + val cudfRegex = new CudfRegexTranspiler(replace = true).transpile(pattern.toString) + strExpr.getBase.replaceRegex(cudfRegex, replaceExpr.getBase) + } catch { + case _: RegexUnsupportedException => + throw new IllegalStateException("Really should not be here, " + + "regular expression should have been verified during tagging") + } + + } + } + + override def doColumnar(numRows: Int, val0: GpuScalar, val1: GpuScalar, + val2: GpuScalar): ColumnVector = { + withResource(GpuColumnVector.from(val0, numRows, srcExpr.dataType)) { val0Col => + doColumnar(val0Col, val1, val2) + } + } + + override def doColumnar( + strExpr: GpuColumnVector, + searchExpr: GpuColumnVector, + replaceExpr: GpuScalar): ColumnVector = + throw new UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") +} + class SubstringIndexMeta( expr: SubstringIndex, override val conf: RapidsConf, diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index fc283fd516c..9bd859f6562 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -23,8 +23,50 @@ import scala.util.{Random, Try} import ai.rapids.cudf.{ColumnVector, CudfException} import org.scalatest.FunSuite +import org.apache.spark.sql.types.DataTypes + class RegularExpressionTranspilerSuite extends FunSuite with Arm { + test("transpiler detects invalid cuDF patterns") { + // The purpose of this test is to document some examples of valid Java regular expressions + // that fail to compile in cuDF and to check that the transpiler detects these correctly. + // Many (but not all) of the patterns here are odd edge cases found during testing with + // random inputs. + val cudfInvalidPatterns = Seq( + "a*+", + "\t+|a", + "(\t+|a)Dc$1", + "(?d)" + ) + // data is not relevant because we are checking for compilation errors + val inputs = Seq("a") + for (pattern <- cudfInvalidPatterns) { + // check that this is valid in Java + Pattern.compile(pattern) + Seq(true, false).foreach { replace => + try { + if (replace) { + gpuReplace(pattern, inputs) + } else { + gpuContains(pattern, inputs) + } + fail(s"cuDF unexpectedly compiled expression: $pattern") + } catch { + case e: CudfException => + // expected, now make sure that the transpiler can detect this + try { + transpile(pattern, replace) + fail( + s"transpiler failed to detect invalid cuDF pattern (replace=$replace): $pattern", e) + } catch { + case _: RegexUnsupportedException => + // expected + } + } + } + } + } + test("cuDF does not support choice with nothing to repeat") { val patterns = Seq("b+|^\t") patterns.foreach(pattern => @@ -94,18 +136,18 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { // see https://github.com/rapidsai/cudf/issues/9620 val pattern = "2$" // this matches "2" but not "2\n" on the GPU - assertCpuGpuContainsMatches(Seq(pattern), Seq("2", "2\n", "2\r", "2\r\n")) + assertCpuGpuMatchesRegexpFind(Seq(pattern), Seq("2", "2\n", "2\r", "2\r\n")) } test("dot matches CR on GPU but not on CPU") { // see https://github.com/rapidsai/cudf/issues/9619 val pattern = "1." 
- assertCpuGpuContainsMatches(Seq(pattern), Seq("1\r2", "1\n2", "1\r\n2")) + assertCpuGpuMatchesRegexpFind(Seq(pattern), Seq("1\r2", "1\n2", "1\r\n2")) } ignore("known issue - octal digit") { val pattern = "a\\141|.$" // using hex works fine e.g. "a\\x61|.$" - assertCpuGpuContainsMatches(Seq(pattern), Seq("] b[")) + assertCpuGpuMatchesRegexpFind(Seq(pattern), Seq("] b[")) } test("character class with ranges") { @@ -121,15 +163,11 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("transpile character class unescaped range symbol") { val patterns = Seq("a[-b]", "a[+-]", "a[-+]", "a[-]", "a[^-]") val expected = Seq(raw"a[\-b]", raw"a[+\-]", raw"a[\-+]", raw"a[\-]", raw"a[^\-]") - val transpiler = new CudfRegexTranspiler() + val transpiler = new CudfRegexTranspiler(replace=false) val transpiled = patterns.map(transpiler.transpile) assert(transpiled === expected) } - test("transpile dot") { - assert(new CudfRegexTranspiler().transpile(".+") === "[^\r\n]+") - } - test("transpile complex regex 1") { val VALID_FLOAT_REGEX = "^" + // start of line @@ -167,37 +205,52 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("compare CPU and GPU: character range including unescaped + and -") { val patterns = Seq("a[-]+", "a[a-b-]+", "a[-a-b]", "a[-+]", "a[+-]") val inputs = Seq("a+", "a-", "a", "a-+", "a[a-b-]") - assertCpuGpuContainsMatches(patterns, inputs) + assertCpuGpuMatchesRegexpFind(patterns, inputs) } test("compare CPU and GPU: character range including escaped + and -") { val patterns = Seq(raw"a[\-\+]", raw"a[\+\-]", raw"a[a-b\-]") val inputs = Seq("a+", "a-", "a", "a-+", "a[a-b-]") - assertCpuGpuContainsMatches(patterns, inputs) + assertCpuGpuMatchesRegexpFind(patterns, inputs) } test("compare CPU and GPU: hex") { val patterns = Seq(raw"\x61") val inputs = Seq("a", "b") - assertCpuGpuContainsMatches(patterns, inputs) + assertCpuGpuMatchesRegexpFind(patterns, inputs) } test("compare CPU and GPU: octal") { val patterns = Seq("\\\\141") val inputs = Seq("a", "b") - assertCpuGpuContainsMatches(patterns, inputs) + assertCpuGpuMatchesRegexpFind(patterns, inputs) } - test("compare CPU and GPU: fuzz test with limited chars") { + private val REGEXP_LIMITED_CHARS = "|()[]{},.^$*+?abc123x\\ \tBsdwSDW" + + test("compare CPU and GPU: regexp find fuzz test with limited chars") { + // testing with this limited set of characters finds issues much + // faster than using the full ASCII set + // CR and LF has been excluded due to known issues + doFuzzTest(Some(REGEXP_LIMITED_CHARS), replace = false) + } + + test("compare CPU and GPU: regexp replace simple regular expressions") { + val inputs = Seq("a", "b", "c") + val patterns = Seq("a|b") + assertCpuGpuMatchesRegexpReplace(patterns, inputs) + } + + test("compare CPU and GPU: regexp replace fuzz test with limited chars") { // testing with this limited set of characters finds issues much // faster than using the full ASCII set // LF has been excluded due to known issues - doFuzzTest(Some("|()[]{},.^$*+?abc123x\\ \r\tB")) + doFuzzTest(Some(REGEXP_LIMITED_CHARS), replace = true) } - test("compare CPU and GPU: fuzz test printable ASCII chars plus CR and TAB") { + test("compare CPU and GPU: regexp find fuzz test printable ASCII chars plus CR and TAB") { // CR and LF has been excluded due to known issues - doFuzzTest(Some((0x20 to 0x7F).map(_.toChar) + "\r\t")) + doFuzzTest(Some((0x20 to 0x7F).map(_.toChar) + "\r\t"), replace = false) } test("compare CPU and GPU: fuzz test ASCII chars") { @@ -205,15 +258,15 @@ class 
RegularExpressionTranspilerSuite extends FunSuite with Arm { val chars = (0x00 to 0x7F) .map(_.toChar) .filterNot(_ == '\n') - doFuzzTest(Some(chars.mkString)) + doFuzzTest(Some(chars.mkString), replace = true) } - ignore("compare CPU and GPU: fuzz test all chars") { + ignore("compare CPU and GPU: regexp find fuzz test all chars") { // this test cannot be enabled until we support CR and LF - doFuzzTest(None) + doFuzzTest(None, replace = false) } - private def doFuzzTest(validChars: Option[String]) { + private def doFuzzTest(validChars: Option[String], replace: Boolean) { val r = new EnhancedRandom(new Random(seed = 0L), options = FuzzerOptions(validChars, maxStringLen = 12)) @@ -226,12 +279,16 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { val patterns = ListBuffer[String]() while (patterns.length < 5000) { val pattern = r.nextString() - if (Try(Pattern.compile(pattern)).isSuccess && Try(transpile(pattern)).isSuccess) { + if (Try(Pattern.compile(pattern)).isSuccess && Try(transpile(pattern, replace)).isSuccess) { patterns += pattern } } - assertCpuGpuContainsMatches(patterns, data) + if (replace) { + assertCpuGpuMatchesRegexpReplace(patterns, data) + } else { + assertCpuGpuMatchesRegexpFind(patterns, data) + } } private def removeTrailingNewlines(input: String): String = { @@ -242,10 +299,10 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { s } - private def assertCpuGpuContainsMatches(javaPatterns: Seq[String], input: Seq[String]) = { + private def assertCpuGpuMatchesRegexpFind(javaPatterns: Seq[String], input: Seq[String]) = { for (javaPattern <- javaPatterns) { val cpu = cpuContains(javaPattern, input) - val cudfPattern = new CudfRegexTranspiler().transpile(javaPattern) + val cudfPattern = new CudfRegexTranspiler(replace = false).transpile(javaPattern) val gpu = try { gpuContains(cudfPattern, input) } catch { @@ -263,6 +320,29 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } } + private def assertCpuGpuMatchesRegexpReplace( + javaPatterns: Seq[String], + input: Seq[String]) = { + for (javaPattern <- javaPatterns) { + val cpu = cpuReplace(javaPattern, input) + val cudfPattern = new CudfRegexTranspiler(replace = true).transpile(javaPattern) + val gpu = try { + gpuReplace(cudfPattern, input) + } catch { + case e: CudfException => + fail(s"cuDF failed to compile pattern: $cudfPattern", e) + } + for (i <- input.indices) { + if (cpu(i) != gpu(i)) { + fail(s"javaPattern=${toReadableString(javaPattern)}, " + + s"cudfPattern=${toReadableString(cudfPattern)}, " + + s"input='${toReadableString(input(i))}', " + + s"cpu=${cpu(i)}, gpu=${gpu(i)}") + } + } + } + } + /** cuDF containsRe helper */ private def gpuContains(cudfPattern: String, input: Seq[String]): Array[Boolean] = { val result = new Array[Boolean](input.length) @@ -276,6 +356,23 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { result } + private val REPLACE_STRING = "_REPLACE_" + + /** cuDF replaceRe helper */ + private def gpuReplace(cudfPattern: String, input: Seq[String]): Array[String] = { + val result = new Array[String](input.length) + withResource(ColumnVector.fromStrings(input: _*)) { cv => + withResource(GpuScalar.from(REPLACE_STRING, DataTypes.StringType)) { replace => + withResource(cv.replaceRegex(cudfPattern, replace)) { c => + withResource(c.copyToHost()) { hv => + result.indices.foreach(i => result(i) = new String(hv.getUTF8(i))) + } + } + } + } + result + } + private def toReadableString(x: String): String = { x.map { case '\r' => "\\r" 
@@ -290,18 +387,23 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { input.map(s => p.matcher(s).find(0)).toArray } + private def cpuReplace(pattern: String, input: Seq[String]): Array[String] = { + val p = Pattern.compile(pattern) + input.map(s => p.matcher(s).replaceAll(REPLACE_STRING)).toArray + } + private def doTranspileTest(pattern: String, expected: String) { - val transpiled: String = transpile(pattern) + val transpiled: String = transpile(pattern, replace = false) assert(transpiled === expected) } - private def transpile(pattern: String): String = { - new CudfRegexTranspiler().transpile(pattern) + private def transpile(pattern: String, replace: Boolean): String = { + new CudfRegexTranspiler(replace).transpile(pattern) } private def assertUnsupported(pattern: String, message: String): Unit = { val e = intercept[RegexUnsupportedException] { - transpile(pattern) + transpile(pattern, replace = false) } assert(e.getMessage.startsWith(message)) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/StringFallbackSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/StringFallbackSuite.scala index 947e6c92484..6ba3deed5c6 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/StringFallbackSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/StringFallbackSuite.scala @@ -15,20 +15,25 @@ */ package com.nvidia.spark.rapids +import org.apache.spark.SparkConf + class StringFallbackSuite extends SparkQueryCompareTestSuite { + + private val conf = new SparkConf().set("spark.rapids.sql.expression.RegExpReplace", "true") + testGpuFallback( "String regexp_replace replace str columnar fall back", "RegExpReplace", nullableStringsFromCsv, execsAllowedNonGpu = Seq("ProjectExec", "Alias", - "RegExpReplace", "AttributeReference", "Literal")) { + "RegExpReplace", "AttributeReference", "Literal"), conf = conf) { frame => frame.selectExpr("regexp_replace(strings,'a',strings)") } testGpuFallback("String regexp_replace null cpu fall back", "RegExpReplace", nullableStringsFromCsv, execsAllowedNonGpu = Seq("ProjectExec", "Alias", - "RegExpReplace", "AttributeReference", "Literal")) { + "RegExpReplace", "AttributeReference", "Literal"), conf = conf) { frame => { // this test is only valid in Spark 3.0.x because the expression is NullIntolerant // since Spark 3.1.0 and gets replaced with a null literal instead @@ -46,49 +51,43 @@ class StringFallbackSuite extends SparkQueryCompareTestSuite { testGpuFallback("String regexp_replace input empty cpu fall back", "RegExpReplace", nullableStringsFromCsv, execsAllowedNonGpu = Seq("ProjectExec", "Alias", - "RegExpReplace", "AttributeReference", "Literal")) { + "RegExpReplace", "AttributeReference", "Literal"), conf = conf) { frame => frame.selectExpr("regexp_replace(strings,'','D')") } testGpuFallback("String regexp_replace regex 1 cpu fall back", "RegExpReplace", nullableStringsFromCsv, execsAllowedNonGpu = Seq("ProjectExec", "Alias", - "RegExpReplace", "AttributeReference", "Literal")) { + "RegExpReplace", "AttributeReference", "Literal"), conf = conf) { frame => frame.selectExpr("regexp_replace(strings,'.*','D')") } - testGpuFallback("String regexp_replace regex 2 cpu fall back", - "RegExpReplace", - nullableStringsFromCsv, execsAllowedNonGpu = Seq("ProjectExec", "Alias", - "RegExpReplace", "AttributeReference", "Literal")) { + testSparkResultsAreEqual("String regexp_replace regex 2", + nullableStringsFromCsv, conf = conf) { frame => frame.selectExpr("regexp_replace(strings,'[a-z]+','D')") } testGpuFallback("String 
regexp_replace regex 3 cpu fall back", "RegExpReplace", nullableStringsFromCsv, execsAllowedNonGpu = Seq("ProjectExec", "Alias", - "RegExpReplace", "AttributeReference", "Literal")) { + "RegExpReplace", "AttributeReference", "Literal"), conf = conf) { frame => frame.selectExpr("regexp_replace(strings,'foo$','D')") } testGpuFallback("String regexp_replace regex 4 cpu fall back", "RegExpReplace", nullableStringsFromCsv, execsAllowedNonGpu = Seq("ProjectExec", "Alias", - "RegExpReplace", "AttributeReference", "Literal")) { + "RegExpReplace", "AttributeReference", "Literal"), conf = conf) { frame => frame.selectExpr("regexp_replace(strings,'^foo','D')") } - testGpuFallback("String regexp_replace regex 5 cpu fall back", - "RegExpReplace", - nullableStringsFromCsv, execsAllowedNonGpu = Seq("ProjectExec", "Alias", - "RegExpReplace", "AttributeReference", "Literal")) { + testSparkResultsAreEqual("String regexp_replace regex 5", + nullableStringsFromCsv, conf = conf) { frame => frame.selectExpr("regexp_replace(strings,'(foo)','D')") } - testGpuFallback("String regexp_replace regex 6 cpu fall back", - "RegExpReplace", - nullableStringsFromCsv, execsAllowedNonGpu = Seq("ProjectExec", "Alias", - "RegExpReplace", "AttributeReference", "Literal")) { + testSparkResultsAreEqual("String regexp_replace regex 6", + nullableStringsFromCsv, conf = conf) { frame => frame.selectExpr("regexp_replace(strings,'\\(foo\\)','D')") } }