From 306637358ffef43efd42d19fd01cabfe560ea32b Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 13 May 2024 19:35:33 +0800 Subject: [PATCH] Add number normalization test and address followup for getJsonObject (#10800) * Add number normalization test and address followup for getJsonObject Signed-off-by: Haoyang Li * Address comment Signed-off-by: Haoyang Li * reduce test cases Signed-off-by: Haoyang Li --------- Signed-off-by: Haoyang Li --- .../src/main/python/get_json_test.py | 25 +++++++++++++++++++ .../spark/rapids/GpuGetJsonObject.scala | 2 +- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/integration_tests/src/main/python/get_json_test.py b/integration_tests/src/main/python/get_json_test.py index ef405db7e33..cfb1c1420ab 100644 --- a/integration_tests/src/main/python/get_json_test.py +++ b/integration_tests/src/main/python/get_json_test.py @@ -391,3 +391,28 @@ def test_get_json_object_number_normalization_legacy(): conf={'spark.rapids.sql.expression.GetJsonObject': 'true', 'spark.rapids.sql.getJsonObject.legacy.enabled': 'true'}) assert([[row[1]] for row in gpu_result] == data) + +@pytest.mark.parametrize('data_gen', [StringGen(r'''-?[1-9]\d{0,5}\.\d{1,20}''', nullable=False), + StringGen(r'''-?[1-9]\d{0,20}\.\d{1,5}''', nullable=False), + StringGen(r'''-?[1-9]\d{0,5}E-?\d{1,20}''', nullable=False), + StringGen(r'''-?[1-9]\d{0,20}E-?\d{1,5}''', nullable=False)], ids=idfn) +def test_get_json_object_floating_normalization(data_gen): + schema = StructType([StructField("jsonStr", StringType())]) + normalization = lambda spark: unary_op_df(spark, data_gen).selectExpr( + 'a', + 'get_json_object(a,"$")' + ).collect() + gpu_res = [[row[1]] for row in with_gpu_session( + normalization, + conf={'spark.rapids.sql.expression.GetJsonObject': 'true'})] + cpu_res = [[row[1]] for row in with_cpu_session(normalization)] + def json_string_to_float(x): + if x == '"-Infinity"': + return float('-inf') + elif x == '"Infinity"': + return float('inf') + else: + return float(x) + for i in range(len(gpu_res)): + # verify relatively diff < 1e-9 (default value for is_close) + assert math.isclose(json_string_to_float(gpu_res[i][0]), json_string_to_float(cpu_res[i][0])) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala index 882c8fec13d..0db1215f2c2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala @@ -89,7 +89,7 @@ object JsonPathParser extends RegexParsers { def fallbackCheck(instructions: List[PathInstruction]): Boolean = { // JNI kernel has a limit of 16 nested nodes, fallback to CPU if we exceed that - instructions.length > 16 + instructions.length > JSONUtils.MAX_PATH_DEPTH } def unzipInstruction(instruction: PathInstruction): (String, String, Long) = {