Skip to content

Commit

Permalink
Fix collection_ops_tests for Spark 4.0.
Browse files Browse the repository at this point in the history
Fixes #11011.

This commit fixes the failures in `collection_ops_tests` on Spark 4.0.

On all versions of Spark, when a Sequence is collected with rows that exceed MAX_INT,
an exception is thrown indicating that the collected Sequence/array is
larger than permissible. The different versions of Spark vary in the
contents of the exception message.

On Spark 4, one sees that the error message now contains more
information than all prior versions, including:
1. The name of the op causing the error
2. The errant sequence size

This commit introduces a shim to make this new information available in
the exception.

Note that this shim does not fit cleanly in RapidsErrorUtils, because
there are differences within major Spark versions. For instance, Spark
3.4.0-1 have a different message as compared to 3.4.2 and 3.4.3.
Likewise, the differences in 3.5.0, 3.5.1, 3.5.2.
  • Loading branch information
mythrocks committed Aug 30, 2024
1 parent 46057a4 commit 5606503
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 10 deletions.
11 changes: 8 additions & 3 deletions integration_tests/src/main/python/collection_ops_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,8 @@
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
from data_gen import *
from pyspark.sql.types import *

from src.main.python.spark_session import is_before_spark_400
from string_test import mk_str_gen
import pyspark.sql.functions as f
import pyspark.sql.utils
Expand Down Expand Up @@ -326,8 +328,11 @@ def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen):
@pytest.mark.parametrize('stop_gen', sequence_too_long_length_gens, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_sequence_too_long_sequence(stop_gen):
msg = "Too long sequence" if is_before_spark_334() or (not is_before_spark_340() and is_before_spark_342()) \
or is_spark_350() else "Unsuccessful try to create array with"
msg = "Too long sequence" if is_before_spark_334() \
or (not is_before_spark_340() and is_before_spark_342()) \
or is_spark_350() \
else "Can't create array" if not is_before_spark_400() \
else "Unsuccessful try to create array with"
assert_gpu_and_cpu_error(
# To avoid OOM, reduce the row number to 1, it is enough to verify this case.
lambda spark:unary_op_df(spark, stop_gen, 1).selectExpr(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids
import java.util.Optional

import ai.rapids.cudf
import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar, SegmentedReductionAggregation, Table}
import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, ReductionAggregation, Scalar, SegmentedReductionAggregation, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.ArrayIndexUtils.firstIndexAndNumElementUnchecked
Expand Down Expand Up @@ -1535,7 +1535,8 @@ object GpuSequenceUtil {
def computeSequenceSize(
start: ColumnVector,
stop: ColumnVector,
step: ColumnVector): ColumnVector = {
step: ColumnVector,
functionName: String): ColumnVector = {
checkSequenceInputs(start, stop, step)
val actualSize = GetSequenceSize(start, stop, step)
val sizeAsLong = withResource(actualSize) { _ =>
Expand All @@ -1557,7 +1558,11 @@ object GpuSequenceUtil {
// check max size
withResource(Scalar.fromInt(MAX_ROUNDED_ARRAY_LENGTH)) { maxLen =>
withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid =>
require(isAllValidTrue(allValid), GetSequenceSize.TOO_LONG_SEQUENCE)
withResource(sizeAsLong.reduce(ReductionAggregation.max())) { maxSizeScalar =>
require(isAllValidTrue(allValid),
GetSequenceSize.TOO_LONG_SEQUENCE(maxSizeScalar.getLong.asInstanceOf[Int],
functionName))
}
}
}
// cast to int and return
Expand Down Expand Up @@ -1597,7 +1602,7 @@ case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expr
val steps = stepGpuColOpt.map(_.getBase.incRefCount())
.getOrElse(defaultStepsFunc(startCol, stopCol))
closeOnExcept(steps) { _ =>
(computeSequenceSize(startCol, stopCol, steps), steps)
(computeSequenceSize(startCol, stopCol, steps, prettyName), steps)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ import com.nvidia.spark.rapids.Arm._
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

object GetSequenceSize {
val TOO_LONG_SEQUENCE = s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
def TOO_LONG_SEQUENCE(sequenceLength: Int, functionName: String) = {
// For these Spark versions, the sequence length and function name
// do not appear in the exception message.
s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
}

/**
* Compute the size of each sequence according to 'start', 'stop' and 'step'.
* A row (Row[start, stop, step]) contains at least one null element will produce
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ import ai.rapids.cudf._
import com.nvidia.spark.rapids.Arm._

import org.apache.spark.sql.rapids.{AddOverflowChecks, SubtractOverflowChecks}
import org.apache.spark.sql.rapids.shims.SequenceSizeError
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

object GetSequenceSize {
val TOO_LONG_SEQUENCE = "Unsuccessful try to create array with elements exceeding the array " +
s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
def TOO_LONG_SEQUENCE(sequenceLength: Int, functionName: String): String =
SequenceSizeError.getTooLongSequenceErrorString(sequenceLength, functionName)
/**
* Compute the size of each sequence according to 'start', 'stop' and 'step'.
* A row (Row[start, stop, step]) contains at least one null element will produce
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "334"}
{"spark": "342"}
{"spark": "343"}
{"spark": "351"}
{"spark": "352"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

object SequenceSizeError {
def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
// The errant function's name does not feature in the exception message
// prior to Spark 4.0. Neither does the attempted allocation size.
"Unsuccessful try to create array with elements exceeding the array " +
s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "400"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.errors.QueryExecutionErrors

object SequenceSizeError {
def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
QueryExecutionErrors.createArrayWithElementsExceedLimitError(functionName, sequenceSize).getMessage
}
}

0 comments on commit 5606503

Please sign in to comment.