Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the cast tests for 3.1.0+ #1166

Merged
merged 1 commit into from
Nov 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 55 additions & 24 deletions tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,87 +67,106 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
// Ansi cast from timestamp to integral types
///////////////////////////////////////////////////////////////////////////

def before3_1_0(s: SparkSession): (Boolean, String) = {
(s.version < "3.1.0", s"Spark version must be prior to 3.1.0")
}

testSparkResultsAreEqual("ansi_cast timestamps to long",
generateValidValuesTimestampsDF(Short.MinValue, Short.MaxValue), sparkConf) {
generateValidValuesTimestampsDF(Short.MinValue, Short.MaxValue), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.LongType)(frame)
}

testSparkResultsAreEqual("ansi_cast successful timestamps to shorts",
generateValidValuesTimestampsDF(Short.MinValue, Short.MaxValue), sparkConf) {
generateValidValuesTimestampsDF(Short.MinValue, Short.MaxValue), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ShortType)(frame)
}

testSparkResultsAreEqual("ansi_cast successful timestamps to ints",
generateValidValuesTimestampsDF(Int.MinValue, Int.MaxValue), sparkConf) {
generateValidValuesTimestampsDF(Int.MinValue, Int.MaxValue), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.IntegerType)(frame)
}

testSparkResultsAreEqual("ansi_cast successful timestamps to bytes",
generateValidValuesTimestampsDF(Byte.MinValue, Byte.MaxValue), sparkConf) {
generateValidValuesTimestampsDF(Byte.MinValue, Byte.MaxValue), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ByteType)(frame)
}

testCastFailsForBadInputs("ansi_cast overflow timestamps to bytes",
generateOutOfRangeTimestampsDF(Byte.MinValue, Byte.MaxValue, Byte.MaxValue + 1), sparkConf) {
generateOutOfRangeTimestampsDF(Byte.MinValue, Byte.MaxValue, Byte.MaxValue + 1), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ByteType)(frame)
}

testCastFailsForBadInputs("ansi_cast underflow timestamps to bytes",
generateOutOfRangeTimestampsDF(Byte.MinValue, Byte.MaxValue, Byte.MinValue - 1), sparkConf) {
generateOutOfRangeTimestampsDF(Byte.MinValue, Byte.MaxValue, Byte.MinValue - 1), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ByteType)(frame)
}

testCastFailsForBadInputs("ansi_cast overflow timestamps to shorts",
generateOutOfRangeTimestampsDF(Short.MinValue, Short.MaxValue, Short.MaxValue + 1), sparkConf) {
generateOutOfRangeTimestampsDF(Short.MinValue, Short.MaxValue, Short.MaxValue + 1), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ShortType)(frame)
}

testCastFailsForBadInputs("ansi_cast underflow timestamps to shorts",
generateOutOfRangeTimestampsDF(Short.MinValue, Short.MaxValue, Short.MinValue - 1), sparkConf) {
generateOutOfRangeTimestampsDF(Short.MinValue, Short.MaxValue, Short.MinValue - 1), sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ShortType)(frame)
}

testCastFailsForBadInputs("ansi_cast overflow timestamps to int",
generateOutOfRangeTimestampsDF(Int.MinValue, Int.MaxValue, Int.MaxValue.toLong + 1),
sparkConf) {
sparkConf, assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.IntegerType)(frame)
}

testCastFailsForBadInputs("ansi_cast underflow timestamps to int",
generateOutOfRangeTimestampsDF(Int.MinValue, Int.MaxValue, Int.MinValue.toLong - 1),
sparkConf) {
sparkConf, assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.IntegerType)(frame)
}

///////////////////////////////////////////////////////////////////////////
// Ansi cast from date
///////////////////////////////////////////////////////////////////////////

testSparkResultsAreEqual("ansi_cast date to bool", testDates, sparkConf) {
testSparkResultsAreEqual("ansi_cast date to bool", testDates, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.BooleanType)(frame)
}

testSparkResultsAreEqual("ansi_cast date to byte", testDates, sparkConf) {
testSparkResultsAreEqual("ansi_cast date to byte", testDates, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ByteType)(frame)
}

testSparkResultsAreEqual("ansi_cast date to short", testDates, sparkConf) {
testSparkResultsAreEqual("ansi_cast date to short", testDates, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.ShortType)(frame)
}

testSparkResultsAreEqual("ansi_cast date to int", testDates, sparkConf) {
testSparkResultsAreEqual("ansi_cast date to int", testDates, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.IntegerType)(frame)
}

testSparkResultsAreEqual("ansi_cast date to long", testDates, sparkConf) {
testSparkResultsAreEqual("ansi_cast date to long", testDates, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.LongType)(frame)
}

testSparkResultsAreEqual("ansi_cast date to float", testDates, sparkConf) {
testSparkResultsAreEqual("ansi_cast date to float", testDates, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.FloatType)(frame)
}

testSparkResultsAreEqual("ansi_cast date to double", testDates, sparkConf) {
testSparkResultsAreEqual("ansi_cast date to double", testDates, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.DoubleType)(frame)
}

Expand Down Expand Up @@ -187,7 +206,8 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
frame => testCastTo(DataTypes.DoubleType)(frame)
}

testSparkResultsAreEqual("ansi_cast bool to timestamp", testBools, sparkConf) {
testSparkResultsAreEqual("ansi_cast bool to timestamp", testBools, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.TimestampType)(frame)
}

Expand Down Expand Up @@ -219,7 +239,8 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
frame => testCastTo(DataTypes.BooleanType)(frame)
}

testSparkResultsAreEqual("ansi_cast timestamp to bool", testTimestamps, sparkConf) {
testSparkResultsAreEqual("ansi_cast timestamp to bool", testTimestamps, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.BooleanType)(frame)
}

Expand Down Expand Up @@ -377,19 +398,23 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
// Ansi cast integral types to timestamp
///////////////////////////////////////////////////////////////////////////

testSparkResultsAreEqual("ansi_cast bytes to timestamp", testBytes, sparkConf) {
testSparkResultsAreEqual("ansi_cast bytes to timestamp", testBytes, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.TimestampType)(frame)
}

testSparkResultsAreEqual("ansi_cast shorts to timestamp", testShorts, sparkConf) {
testSparkResultsAreEqual("ansi_cast shorts to timestamp", testShorts, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.TimestampType)(frame)
}

testSparkResultsAreEqual("ansi_cast ints to timestamp", testInts, sparkConf) {
testSparkResultsAreEqual("ansi_cast ints to timestamp", testInts, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.TimestampType)(frame)
}

testSparkResultsAreEqual("ansi_cast longs to timestamp", testLongs, sparkConf) {
testSparkResultsAreEqual("ansi_cast longs to timestamp", testLongs, sparkConf,
assumeCondition = before3_1_0) {
frame => testCastTo(DataTypes.TimestampType)(frame)
}

Expand Down Expand Up @@ -613,10 +638,16 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
testName: String,
frame: SparkSession => DataFrame,
sparkConf: SparkConf = sparkConf,
msg: String = GpuCast.INVALID_INPUT_MESSAGE)(transformation: DataFrame => DataFrame)
msg: String = GpuCast.INVALID_INPUT_MESSAGE,
assumeCondition: SparkSession => (Boolean, String) = null)
(transformation: DataFrame => DataFrame)
: Unit = {

test(testName) {
if (assumeCondition != null) {
val (isAllowed, reason) = withCpuSparkSession(assumeCondition, conf = sparkConf)
assume(isAllowed, reason)
}
try {
withGpuSparkSession(spark => {
val input = frame(spark).repartition(1)
Expand Down
18 changes: 16 additions & 2 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,17 @@ class CastOpSuite extends GpuExpressionTestSuite {
for (from <- supportedTypes; to <- supportedTypes) yield (from, to)
}

def should310SkipAnsiCast(from: DataType, to: DataType): Boolean = (from, to) match {
case (_: NumericType, TimestampType | DateType) => true
case (BooleanType, TimestampType | DateType) => true
case (TimestampType | DateType, _: NumericType) => true
case (TimestampType | DateType, BooleanType) => true
case _ => false
}


test("Test all supported casts with in-range values") {
val is310OrAfter = !withCpuSparkSession(s => s.version < "3.1.0")

// test cast() and ansi_cast()
Seq(false, true).foreach { ansiEnabled =>
Expand All @@ -63,8 +73,12 @@ class CastOpSuite extends GpuExpressionTestSuite {

typeMatrix.foreach {
case (from, to) =>
// In 3.1.0 Cast.canCast was split with a separate ANSI version
// Until we are on 3.1.0 or more we cannot call this easily so for now
// We will check and skip a very specific one.
val shouldSkip = is310OrAfter && ansiEnabled && should310SkipAnsiCast(to, from)
// check if Spark supports this cast
if (Cast.canCast(from, to)) {
if (!shouldSkip && Cast.canCast(from, to)) {
// check if plugin supports this cast
if (GpuCast.canCast(from, to)) {
// test the cast
Expand All @@ -87,7 +101,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
fail(s"Cast from $from to $to failed; ansi=$ansiEnabled", e)
}
}
} else {
} else if (!shouldSkip) {
// if Spark doesn't support this cast then the plugin shouldn't either
assert(!GpuCast.canCast(from, to))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,13 +698,19 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm {
maxFloatDiff: Double = 0.0,
incompat: Boolean = false,
execsAllowedNonGpu: Seq[String] = Seq.empty,
sortBeforeRepart: Boolean = false)
sortBeforeRepart: Boolean = false,
assumeCondition: SparkSession => (Boolean, String) = null)
(fun: DataFrame => DataFrame): Unit = {

val (testConf, qualifiedTestName) =
setupTestConfAndQualifierName(testName, incompat, sort, conf, execsAllowedNonGpu,
maxFloatDiff, sortBeforeRepart)

test(qualifiedTestName) {
if (assumeCondition != null) {
val (isAllowed, reason) = withCpuSparkSession(assumeCondition, conf = testConf)
assume(isAllowed, reason)
}
val (fromCpu, fromGpu) = runOnCpuAndGpu(df, fun,
conf = testConf,
repart = repart)
Expand Down