Skip to content

Commit

Permalink
Better float/double cases for casting tests (#1781)
Browse files Browse the repository at this point in the history
* enhance float/double cases for casting tests

Signed-off-by: sperlingxx <lovedreamf@gmail.com>

* continue

* code clean

* code clean

* fix typo

* fix typo

* some updates

* fix typo
  • Loading branch information
sperlingxx authored Feb 23, 2021
1 parent f45a3b7 commit 52d95b1
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 205 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -338,31 +338,31 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
frame => testCastTo(DataTypes.DoubleType)(frame)
}

testCastFailsForBadInputs("Test bad cast 1 from strings to floats", badFloatStringsDf,
testCastFailsForBadInputs("Test bad cast 1 from strings to floats", invalidFloatStringsDf,
msg = GpuCast.INVALID_FLOAT_CAST_MSG) {
frame =>frame.select(col("c0").cast(FloatType))
}

testCastFailsForBadInputs("Test bad cast 2 from strings to floats", badFloatStringsDf,
testCastFailsForBadInputs("Test bad cast 2 from strings to floats", invalidFloatStringsDf,
msg = GpuCast.INVALID_FLOAT_CAST_MSG) {
frame =>frame.select(col("c1").cast(FloatType))
}

testCastFailsForBadInputs("Test bad cast 1 from strings to double", badFloatStringsDf,
testCastFailsForBadInputs("Test bad cast 1 from strings to double", invalidFloatStringsDf,
msg = GpuCast.INVALID_FLOAT_CAST_MSG) {
frame =>frame.select(col("c0").cast(DoubleType))
}

testCastFailsForBadInputs("Test bad cast 2 from strings to double", badFloatStringsDf,
testCastFailsForBadInputs("Test bad cast 2 from strings to double", invalidFloatStringsDf,
msg = GpuCast.INVALID_FLOAT_CAST_MSG) {
frame =>frame.select(col("c1").cast(DoubleType))
}

//Currently there is a bug in cudf which doesn't convert one value correctly
// Currently there is a bug in cudf which doesn't convert some corner cases correctly
// The bug is documented here https://github.com/rapidsai/cudf/issues/5225
ignore("Test cast from strings to double that doesn't match") {
testSparkResultsAreEqual("Test cast from strings to double that doesn't match",
badDoubleStringsDf) {
badDoubleStringsDf, conf = sparkConf, maxFloatDiff = 0.0001) {
frame =>frame.select(
col("c0").cast(DoubleType))
}
Expand All @@ -372,12 +372,12 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
// Ansi cast from floating point to string
///////////////////////////////////////////////////////////////////////////

ignore("ansi_cast float to string") {
test("ansi_cast float to string") {
testCastToString[Float](DataTypes.FloatType, ansiMode = true,
comparisonFunc = Some(compareStringifiedFloats))
}

ignore("ansi_cast double to string") {
test("ansi_cast double to string") {
testCastToString[Double](DataTypes.DoubleType, ansiMode = true,
comparisonFunc = Some(compareStringifiedFloats))
}
Expand Down
62 changes: 15 additions & 47 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -203,16 +203,6 @@ class CastOpSuite extends GpuExpressionTestSuite {
}
}

private def testCastTo(castTo: DataType)(frame: DataFrame): DataFrame ={
frame.withColumn("c1", col("c0").cast(castTo))
}

private def stringDf(str: String)(session: SparkSession): DataFrame = {
import session.sqlContext.implicits._
// use more than one value otherwise spark optimizes it out as a literal
Seq(str, str).toDF("c0")
}

private def castToStringExpectedFun[T]: T => Option[String] = (d: T) => Some(String.valueOf(d))

test("cast byte to string") {
Expand All @@ -231,11 +221,11 @@ class CastOpSuite extends GpuExpressionTestSuite {
testCastToString[Long](DataTypes.LongType)
}

ignore("cast float to string") {
test("cast float to string") {
testCastToString[Float](DataTypes.FloatType, comparisonFunc = Some(compareStringifiedFloats))
}

ignore("cast double to string") {
test("cast double to string") {
testCastToString[Double](DataTypes.DoubleType, comparisonFunc = Some(compareStringifiedFloats))
}

Expand Down Expand Up @@ -298,30 +288,6 @@ class CastOpSuite extends GpuExpressionTestSuite {
col("doubles").cast(TimestampType))
}

ignore("Test cast from double to string") {

//NOTE that the testSparkResultsAreEqual method isn't adequate in this case because we
// need to use a specialized comparison function

val conf = new SparkConf()
.set(RapidsConf.ENABLE_CAST_FLOAT_TO_STRING.key, "true")

val (cpu, gpu) = runOnCpuAndGpu(doubleDf, frame => frame.select(
col("doubles").cast(StringType))
.orderBy(col("doubles")), conf)

val fromCpu = cpu.map(row => row.getAs[String](0))
val fromGpu = gpu.map(row => row.getAs[String](0))

fromCpu.zip(fromGpu).foreach {
case (c, g) =>
if (!compareStringifiedFloats(c, g)) {
fail(s"Running on the GPU and on the CPU did not match: CPU value: $c. " +
s"GPU value: $g.")
}
}
}

testSparkResultsAreEqual("Test cast from boolean", booleanDf) {
frame => frame.select(
col("bools").cast(IntegerType),
Expand Down Expand Up @@ -396,14 +362,6 @@ class CastOpSuite extends GpuExpressionTestSuite {
col("doubles").cast(TimestampType))
}

ignore("Test cast from strings to double that doesn't match") {
testSparkResultsAreEqual("Test cast from strings to double that doesn't match",
badDoubleStringsDf) {
frame =>frame.select(
col("doubles").cast(DoubleType))
}
}

testSparkResultsAreEqual("Test cast from strings to doubles", doublesAsStrings,
conf = sparkConf, maxFloatDiff = 0.0001) {
frame => frame.select(
Expand All @@ -416,7 +374,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
col("c0").cast(FloatType))
}

testSparkResultsAreEqual("Test bad cast from strings to floats", badFloatStringsDf,
testSparkResultsAreEqual("Test bad cast from strings to floats", invalidFloatStringsDf,
conf = sparkConf, maxFloatDiff = 0.0001) {
frame =>frame.select(
col("c0").cast(DoubleType),
Expand All @@ -425,6 +383,16 @@ class CastOpSuite extends GpuExpressionTestSuite {
col("c1").cast(FloatType))
}

// Currently there is a bug in cudf which doesn't convert some corner cases correctly
// The bug is documented here https://github.com/rapidsai/cudf/issues/5225
ignore("Test cast from strings to double that doesn't match") {
testSparkResultsAreEqual("Test cast from strings to double that doesn't match",
badDoubleStringsDf, conf = sparkConf, maxFloatDiff = 0.0001) {
frame => frame.select(
col("c0").cast(DoubleType))
}
}

testSparkResultsAreEqual("ansi_cast string to double exp", exponentsAsStringsDf,
conf = sparkConf, maxFloatDiff = 0.0001) {
frame => frame.select(
Expand Down Expand Up @@ -720,13 +688,13 @@ object CastOpSuite {

def doublesAsStrings(session: SparkSession): DataFrame = {
val schema = FuzzerUtils.createSchema(Seq(DoubleType), false)
val df = FuzzerUtils.generateDataFrame(session, schema, 100)
val df = FuzzerUtils.generateDataFrame(session, schema, 2048)
df.withColumn("c0", col("c0").cast(StringType))
}

def floatsAsStrings(session: SparkSession): DataFrame = {
val schema = FuzzerUtils.createSchema(Seq(FloatType), false)
val df = FuzzerUtils.generateDataFrame(session, schema, 100)
val df = FuzzerUtils.generateDataFrame(session, schema, 2048)
df.withColumn("c0", col("c0").cast(StringType))
}

Expand Down
16 changes: 10 additions & 6 deletions tests/src/test/scala/com/nvidia/spark/rapids/FuzzerUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -301,30 +301,34 @@ class EnhancedRandom(protected val r: Random, protected val options: FuzzerOptio
}

def nextFloat(): Float = {
r.nextInt(9) match {
r.nextInt(11) match {
case 0 => Float.NaN
case 1 => Float.PositiveInfinity
case 2 => Float.NegativeInfinity
case 3 => r.nextFloat() * Float.MinValue
case 4 => r.nextFloat() * Float.MaxValue
case 3 => Float.MinValue
case 4 => Float.MaxValue
case 5 => 0 - r.nextFloat()
case 6 => r.nextFloat()
case 7 => 0f
case 8 => -0f
case 9 => r.nextFloat() * Float.MinValue
case 10 => r.nextFloat() * Float.MaxValue
}
}

def nextDouble(): Double = {
r.nextInt(9) match {
r.nextInt(11) match {
case 0 => Double.NaN
case 1 => Double.PositiveInfinity
case 2 => Double.NegativeInfinity
case 3 => r.nextDouble() * Double.MinValue
case 4 => r.nextDouble() * Double.MaxValue
case 3 => Double.MaxValue
case 4 => Double.MinValue
case 5 => 0 - r.nextDouble()
case 6 => r.nextDouble()
case 7 => 0d
case 8 => -0d
case 9 => r.nextDouble() * Double.MinValue
case 10 => r.nextDouble() * Double.MaxValue
}
}

Expand Down
Loading

0 comments on commit 52d95b1

Please sign in to comment.