Skip to content

Commit

Permalink
Improve the diagnostics for 'conv' fallback explain (NVIDIA#11076)
Browse files Browse the repository at this point in the history
* Improve the diagnostics for 'conv' fallback explain

Signed-off-by: Jihoon Son <ghoonson@gmail.com>

* don't use nil

Signed-off-by: Jihoon Son <ghoonson@gmail.com>

* the bases should not be an empty string in the error message when the user input is not

Signed-off-by: Jihoon Son <ghoonson@gmail.com>

* more user-friendly message

* Update sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala

Co-authored-by: Gera Shegalov <gshegalov@nvidia.com>

---------

Signed-off-by: Jihoon Son <ghoonson@gmail.com>
Co-authored-by: Gera Shegalov <gshegalov@nvidia.com>
  • Loading branch information
jihoonson and gerashegalov authored Jun 25, 2024
1 parent b3b5b5e commit 6455396
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
16 changes: 16 additions & 0 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,22 @@ def test_conv_dec_to_from_hex(from_base, to_base, pattern):
conf={'spark.rapids.sql.expression.Conv': True}
)

@pytest.mark.parametrize(
    'from_base,to_base,expected_err_msg_prefix',
    [
        pytest.param(10, 15, '15 is not a supported target radix', id='to_base_unsupported'),
        pytest.param(11, 16, '11 is not a supported source radix', id='from_base_unsupported'),
        pytest.param(9, 17, 'both 9 and 17 are not a supported radix', id='both_base_unsupported'),
    ])
def test_conv_unsupported_base(from_base, to_base, expected_err_msg_prefix):
    """Check that the potential-GPU-plan explain output names the unsupported radix.

    For each (from_base, to_base) pair, a ``conv`` projection is built over a
    generated string column and the text returned by
    ``ExplainPlan.explainPotentialGpuPlan`` must contain the expected prefix
    followed by the shared "only literal 10 or 16 ..." suffix.
    """
    def check_explain_mentions_radix(spark):
        # Build a single-column string DataFrame and apply conv with the bases under test.
        source_df = unary_op_df(spark, StringGen())
        converted = f.conv(f.col('a'), from_base, to_base)
        projected = source_df.select('a', converted)

        # Ask the plugin (through the JVM gateway) to explain the plan for this DataFrame.
        explain_plan = spark.sparkContext._jvm.com.nvidia.spark.rapids.ExplainPlan
        explain_text = explain_plan.explainPotentialGpuPlan(projected._jdf, "ALL")

        expected_reason = f'{expected_err_msg_prefix}, only literal 10 or 16 are supported for source and target radixes'
        assert expected_reason in explain_text

    with_cpu_session(check_explain_mentions_radix)

format_number_gens = integral_gens + [DecimalGen(precision=7, scale=7), DecimalGen(precision=18, scale=0),
DecimalGen(precision=18, scale=3), DecimalGen(precision=36, scale=5),
DecimalGen(precision=36, scale=-5), DecimalGen(precision=38, scale=10),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2084,11 +2084,25 @@ class GpuConvMeta(
/**
 * Tags this `conv` expression as not runnable on the GPU unless both the source
 * and the target radix are integer literals equal to 10 or 16.
 *
 * Each unsupported combination is reported through `willNotWorkOnGpu` with a
 * reason that names the offending radix (source, target, or both), so the
 * fallback explanation tells the user which argument to change.
 */
override def tagExprForGpu(): Unit = {
  val fromBaseLit = GpuOverrides.extractLit(expr.fromBaseExpr)
  val toBaseLit = GpuOverrides.extractLit(expr.toBaseExpr)
  val errorPostfix = "only literal 10 or 16 are supported for source and target radixes"

  // The literal values are untyped here, so compare against the two supported radixes directly.
  def isBaseSupported(base: Any): Boolean = base == 10 || base == 16

  (fromBaseLit, toBaseLit) match {
    case (Some(Literal(fromBaseVal, IntegerType)), Some(Literal(toBaseVal, IntegerType))) =>
      (isBaseSupported(fromBaseVal), isBaseSupported(toBaseVal)) match {
        case (true, true) =>
          // Both radixes are supported: nothing to tag.
          ()
        case (false, false) =>
          willNotWorkOnGpu(because = s"both ${fromBaseVal} and ${toBaseVal} are not " +
            s"a supported radix, ${errorPostfix}")
        case (false, true) =>
          willNotWorkOnGpu(because = s"${fromBaseVal} is not a supported source radix, " +
            s"${errorPostfix}")
        case (true, false) =>
          willNotWorkOnGpu(because = s"${toBaseVal} is not a supported target radix, " +
            s"${errorPostfix}")
      }
    case _ =>
      // Reached when either base could not be extracted as an IntegerType literal.
      // Per the original author's note this is not expected in production because the
      // function signature enforces integer types for the bases; it is kept as an
      // edge-case guard. Exactly one reason is reported so the explain output does not
      // carry two differently worded messages for the same condition.
      willNotWorkOnGpu(because = "either source radix or target radix is not an integer " +
        "literal, " + errorPostfix)
  }
}

Expand Down

0 comments on commit 6455396

Please sign in to comment.