Skip to content

Commit

Permalink
Fix the cast tests for 3.1.0+ (NVIDIA#1166)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
  • Loading branch information
revans2 authored Nov 19, 2020
1 parent da72e97 commit fcef8da
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 27 deletions.
79 changes: 55 additions & 24 deletions tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,87 +67,106 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
// Ansi cast from timestamp to integral types
///////////////////////////////////////////////////////////////////////////

/**
 * Assumption check used to skip tests on Spark 3.1.0 and later, where the
 * ANSI cast semantics for date/timestamp <-> numeric/boolean changed.
 *
 * @param s the active CPU SparkSession, used only to read its version string
 * @return (true, reason) when the session's Spark version is before 3.1.0
 *
 * Note: version components are compared numerically. A plain string
 * comparison (the obvious `s.version < "3.1.0"`) would misorder
 * multi-digit components, e.g. it puts "3.10.0" before "3.2.0".
 */
def before3_1_0(s: SparkSession): (Boolean, String) = {
  // Leading numeric components of a version, ignoring suffixes like "-SNAPSHOT".
  def parts(v: String): Array[Int] =
    v.split("[.-]").map(_.takeWhile(_.isDigit)).takeWhile(_.nonEmpty).map(_.toInt)
  // Element-wise numeric comparison; missing trailing components count as 0.
  def lessThan(a: Array[Int], b: Array[Int]): Boolean =
    a.zipAll(b, 0, 0).find { case (x, y) => x != y }.exists { case (x, y) => x < y }
  (lessThan(parts(s.version), Array(3, 1, 0)), "Spark version must be prior to 3.1.0")
}

testSparkResultsAreEqual("ansi_cast timestamps to long",
generateValidValuesTimestampsDF(Short.MinValue, Short.MaxValue), sparkConf) {
generateValidValuesTimestampsDF(Short.MinValue, Short.MaxValue), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.LongType)(frame)
}

testSparkResultsAreEqual("ansi_cast successful timestamps to shorts",
generateValidValuesTimestampsDF(Short.MinValue, Short.MaxValue), sparkConf) {
generateValidValuesTimestampsDF(Short.MinValue, Short.MaxValue), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ShortType)(frame)
}

testSparkResultsAreEqual("ansi_cast successful timestamps to ints",
generateValidValuesTimestampsDF(Int.MinValue, Int.MaxValue), sparkConf) {
generateValidValuesTimestampsDF(Int.MinValue, Int.MaxValue), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.IntegerType)(frame)
}

testSparkResultsAreEqual("ansi_cast successful timestamps to bytes",
generateValidValuesTimestampsDF(Byte.MinValue, Byte.MaxValue), sparkConf) {
generateValidValuesTimestampsDF(Byte.MinValue, Byte.MaxValue), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ByteType)(frame)
}

testCastFailsForBadInputs("ansi_cast overflow timestamps to bytes",
generateOutOfRangeTimestampsDF(Byte.MinValue, Byte.MaxValue, Byte.MaxValue + 1), sparkConf) {
generateOutOfRangeTimestampsDF(Byte.MinValue, Byte.MaxValue, Byte.MaxValue + 1), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ByteType)(frame)
}

testCastFailsForBadInputs("ansi_cast underflow timestamps to bytes",
generateOutOfRangeTimestampsDF(Byte.MinValue, Byte.MaxValue, Byte.MinValue - 1), sparkConf) {
generateOutOfRangeTimestampsDF(Byte.MinValue, Byte.MaxValue, Byte.MinValue - 1), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ByteType)(frame)
}

testCastFailsForBadInputs("ansi_cast overflow timestamps to shorts",
generateOutOfRangeTimestampsDF(Short.MinValue, Short.MaxValue, Short.MaxValue + 1), sparkConf) {
generateOutOfRangeTimestampsDF(Short.MinValue, Short.MaxValue, Short.MaxValue + 1), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ShortType)(frame)
}

testCastFailsForBadInputs("ansi_cast underflow timestamps to shorts",
generateOutOfRangeTimestampsDF(Short.MinValue, Short.MaxValue, Short.MinValue - 1), sparkConf) {
generateOutOfRangeTimestampsDF(Short.MinValue, Short.MaxValue, Short.MinValue - 1), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ShortType)(frame)
}

testCastFailsForBadInputs("ansi_cast overflow timestamps to int",
generateOutOfRangeTimestampsDF(Int.MinValue, Int.MaxValue, Int.MaxValue.toLong + 1),
sparkConf) {
sparkConf, assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.IntegerType)(frame)
}

testCastFailsForBadInputs("ansi_cast underflow timestamps to int",
generateOutOfRangeTimestampsDF(Int.MinValue, Int.MaxValue, Int.MinValue.toLong - 1),
sparkConf) {
sparkConf, assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.IntegerType)(frame)
}

///////////////////////////////////////////////////////////////////////////
// Ansi cast from date
///////////////////////////////////////////////////////////////////////////

testSparkResultsAreEqual("ansi_cast date to bool", testDates, sparkConf) {
testSparkResultsAreEqual("ansi_cast date to bool", testDates, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.BooleanType)(frame)
}

testSparkResultsAreEqual("ansi_cast date to byte", testDates, sparkConf) {
testSparkResultsAreEqual("ansi_cast date to byte", testDates, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ByteType)(frame)
}

testSparkResultsAreEqual("ansi_cast date to short", testDates, sparkConf) {
testSparkResultsAreEqual("ansi_cast date to short", testDates, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ShortType)(frame)
}

testSparkResultsAreEqual("ansi_cast date to int", testDates, sparkConf) {
testSparkResultsAreEqual("ansi_cast date to int", testDates, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.IntegerType)(frame)
}

testSparkResultsAreEqual("ansi_cast date to long", testDates, sparkConf) {
testSparkResultsAreEqual("ansi_cast date to long", testDates, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.LongType)(frame)
}

testSparkResultsAreEqual("ansi_cast date to float", testDates, sparkConf) {
testSparkResultsAreEqual("ansi_cast date to float", testDates, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.FloatType)(frame)
}

testSparkResultsAreEqual("ansi_cast date to double", testDates, sparkConf) {
testSparkResultsAreEqual("ansi_cast date to double", testDates, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.DoubleType)(frame)
}

Expand Down Expand Up @@ -187,7 +206,8 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
frame => testCastTo(DataTypes.DoubleType)(frame)
}

testSparkResultsAreEqual("ansi_cast bool to timestamp", testBools, sparkConf) {
testSparkResultsAreEqual("ansi_cast bool to timestamp", testBools, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.TimestampType)(frame)
}

Expand Down Expand Up @@ -219,7 +239,8 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
frame => testCastTo(DataTypes.BooleanType)(frame)
}

testSparkResultsAreEqual("ansi_cast timestamp to bool", testTimestamps, sparkConf) {
testSparkResultsAreEqual("ansi_cast timestamp to bool", testTimestamps, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.BooleanType)(frame)
}

Expand Down Expand Up @@ -377,19 +398,23 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
// Ansi cast integral types to timestamp
///////////////////////////////////////////////////////////////////////////

testSparkResultsAreEqual("ansi_cast bytes to timestamp", testBytes, sparkConf) {
testSparkResultsAreEqual("ansi_cast bytes to timestamp", testBytes, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.TimestampType)(frame)
}

testSparkResultsAreEqual("ansi_cast shorts to timestamp", testShorts, sparkConf) {
testSparkResultsAreEqual("ansi_cast shorts to timestamp", testShorts, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.TimestampType)(frame)
}

testSparkResultsAreEqual("ansi_cast ints to timestamp", testInts, sparkConf) {
testSparkResultsAreEqual("ansi_cast ints to timestamp", testInts, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.TimestampType)(frame)
}

testSparkResultsAreEqual("ansi_cast longs to timestamp", testLongs, sparkConf) {
testSparkResultsAreEqual("ansi_cast longs to timestamp", testLongs, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.TimestampType)(frame)
}

Expand Down Expand Up @@ -613,10 +638,16 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
testName: String,
frame: SparkSession => DataFrame,
sparkConf: SparkConf = sparkConf,
msg: String = GpuCast.INVALID_INPUT_MESSAGE)(transformation: DataFrame => DataFrame)
msg: String = GpuCast.INVALID_INPUT_MESSAGE,
assumeCondition: SparkSession => (Boolean, String) = null)
(transformation: DataFrame => DataFrame)
: Unit = {

test(testName) {
if (assumeCondition != null) {
val (isAllowed, reason) = withCpuSparkSession(assumeCondition, conf = sparkConf)
assume(isAllowed, reason)
}
try {
withGpuSparkSession(spark => {
val input = frame(spark).repartition(1)
Expand Down
18 changes: 16 additions & 2 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,17 @@ class CastOpSuite extends GpuExpressionTestSuite {
for (from <- supportedTypes; to <- supportedTypes) yield (from, to)
}

/**
 * Identifies casts whose ANSI behavior changed in Spark 3.1.0: any cast
 * between a date/time type and a numeric or boolean type, in either
 * direction. Tests covering these casts are skipped on 3.1.0+.
 *
 * @param from the source type of the cast
 * @param to the target type of the cast
 * @return true if this (from, to) pair should be skipped under ANSI mode
 */
def should310SkipAnsiCast(from: DataType, to: DataType): Boolean = {
  // Types on the numeric/boolean side of the affected casts.
  def numericOrBool(t: DataType): Boolean = t match {
    case _: NumericType => true
    case BooleanType => true
    case _ => false
  }
  // Types on the date/time side of the affected casts.
  def dateTime(t: DataType): Boolean = t == TimestampType || t == DateType
  // The change is symmetric: affected in either direction.
  (numericOrBool(from) && dateTime(to)) || (dateTime(from) && numericOrBool(to))
}


test("Test all supported casts with in-range values") {
val is310OrAfter = !withCpuSparkSession(s => s.version < "3.1.0")

// test cast() and ansi_cast()
Seq(false, true).foreach { ansiEnabled =>
Expand All @@ -63,8 +73,12 @@ class CastOpSuite extends GpuExpressionTestSuite {

typeMatrix.foreach {
case (from, to) =>
// In 3.1.0 Cast.canCast was split into a separate ANSI version.
// Until we build against 3.1.0 or later we cannot call it easily,
// so for now we check for and skip a very specific set of casts.
val shouldSkip = is310OrAfter && ansiEnabled && should310SkipAnsiCast(to, from)
// check if Spark supports this cast
if (Cast.canCast(from, to)) {
if (!shouldSkip && Cast.canCast(from, to)) {
// check if plugin supports this cast
if (GpuCast.canCast(from, to)) {
// test the cast
Expand All @@ -87,7 +101,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
fail(s"Cast from $from to $to failed; ansi=$ansiEnabled", e)
}
}
} else {
} else if (!shouldSkip) {
// if Spark doesn't support this cast then the plugin shouldn't either
assert(!GpuCast.canCast(from, to))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,13 +698,19 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm {
maxFloatDiff: Double = 0.0,
incompat: Boolean = false,
execsAllowedNonGpu: Seq[String] = Seq.empty,
sortBeforeRepart: Boolean = false)
sortBeforeRepart: Boolean = false,
assumeCondition: SparkSession => (Boolean, String) = null)
(fun: DataFrame => DataFrame): Unit = {

val (testConf, qualifiedTestName) =
setupTestConfAndQualifierName(testName, incompat, sort, conf, execsAllowedNonGpu,
maxFloatDiff, sortBeforeRepart)

test(qualifiedTestName) {
if (assumeCondition != null) {
val (isAllowed, reason) = withCpuSparkSession(assumeCondition, conf = testConf)
assume(isAllowed, reason)
}
val (fromCpu, fromGpu) = runOnCpuAndGpu(df, fun,
conf = testConf,
repart = repart)
Expand Down

0 comments on commit fcef8da

Please sign in to comment.