Enable ANSI mode for CAST string to timestamp (NVIDIA#1555)
Signed-off-by: Andy Grove <andygrove@nvidia.com>
andygrove authored Jan 25, 2021
1 parent ef7020c commit ce8b809
Showing 2 changed files with 55 additions and 2 deletions.
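For context, here is a sketch of the end-to-end behavior this commit enables. It is a hypothetical session, not part of the diff: the literal key for RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP is assumed, and spark.sql.ansi.enabled is the standard Spark ANSI flag.

// Hypothetical illustration only. With ANSI mode on, a malformed
// string-to-timestamp cast should raise an error instead of silently
// producing NULL.
spark.conf.set("spark.sql.ansi.enabled", "true")
spark.conf.set("spark.rapids.sql.castStringToTimestamp.enabled", "true") // assumed key

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DataTypes
import spark.implicits._

val df = Seq("2021-01-25 12:00:00", "not-a-timestamp").toDF("c0")
  .withColumn("c1", col("c0").cast(DataTypes.TimestampType))

df.collect() // expected to fail: "not-a-timestamp" cannot be converted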
22 changes: 21 additions & 1 deletion sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -17,6 +17,7 @@
 package com.nvidia.spark.rapids
 
 import java.text.SimpleDateFormat
+import java.time.DateTimeException
 
 import ai.rapids.cudf.{ColumnVector, DType, Scalar}
 
@@ -784,8 +785,27 @@ case class GpuCast(
           convertTimestampOrNull(sanitizedInput, TIMESTAMP_REGEX_YYYY, "%Y"))))
 
       // handle special dates like "epoch", "now", etc.
-      specialDates.foldLeft(converted)((prev, specialDate) =>
+      val finalResult = specialDates.foldLeft(converted)((prev, specialDate) =>
         specialTimestampOr(sanitizedInput, specialDate._1, specialDate._2, prev))
+
+      // When ANSI mode is enabled, we need to throw an exception if any values could not be
+      // converted
+      if (ansiMode) {
+        closeOnExcept(finalResult) { finalResult =>
+          withResource(input.isNotNull) { wasNotNull =>
+            withResource(finalResult.isNull) { isNull =>
+              withResource(wasNotNull.and(isNull)) { notConverted =>
+                if (notConverted.any().getBoolean) {
+                  throw new DateTimeException(
+                    "One or more values could not be converted to TimestampType")
+                }
+              }
+            }
+          }
+        }
+      }
+
+      finalResult
     }
   }
 
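The nested withResource and closeOnExcept calls above are the plugin's resource-management helpers for GPU column vectors. A minimal sketch of their semantics, assuming simplified generic signatures (illustrative; the real implementations live in the plugin's Arm trait):

// withResource: close the resource whether or not the body throws.
def withResource[R <: AutoCloseable, T](r: R)(block: R => T): T = {
  try {
    block(r)
  } finally {
    r.close()
  }
}

// closeOnExcept: close the resource only if the body throws; on success
// the caller keeps ownership, which is why finalResult can still be
// returned after the ANSI check.
def closeOnExcept[R <: AutoCloseable, T](r: R)(block: R => T): T = {
  try {
    block(r)
  } catch {
    case t: Throwable =>
      r.close()
      throw t
  }
}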
35 changes: 34 additions & 1 deletion tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
@@ -17,10 +17,11 @@
 package com.nvidia.spark.rapids
 
 import java.sql.Timestamp
+import java.time.DateTimeException
 
 import scala.util.Random
 
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, CastBase}
 import org.apache.spark.sql.execution.ProjectExec
@@ -38,6 +39,7 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
     .set(RapidsConf.ENABLE_CAST_FLOAT_TO_STRING.key, "true")
     .set(RapidsConf.ENABLE_CAST_STRING_TO_INTEGER.key, "true")
     .set(RapidsConf.ENABLE_CAST_STRING_TO_FLOAT.key, "true")
+    .set(RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP.key, "true")
 
   def generateOutOfRangeTimestampsDF(
       lowerValue: Long,
@@ -420,6 +422,37 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
     frame => testCastTo(DataTypes.TimestampType)(frame)
   }
 
+  test("ANSI mode: cast string to timestamp with parse error") {
+    // Copied from Spark CastSuite
+
+    def checkCastWithParseError(str: String): Unit = {
+      val exception = intercept[SparkException] {
+        withGpuSparkSession(spark => {
+          import spark.implicits._
+
+          val df = Seq(str).toDF("c0")
+            .repartition(2)
+            .withColumn("c1", col("c0").cast(DataTypes.TimestampType))
+
+          val result = df.collect()
+          result.foreach(println)
+
+        }, sparkConf)
+      }
+      assert(exception.getCause.isInstanceOf[DateTimeException])
+    }
+
+    checkCastWithParseError("123")
+    checkCastWithParseError("2015-03-18 123142")
+    checkCastWithParseError("2015-03-18T123123")
+    checkCastWithParseError("2015-03-18X")
+    checkCastWithParseError("2015/03/18")
+    checkCastWithParseError("2015.03.18")
+    checkCastWithParseError("20150318")
+    checkCastWithParseError("2015-031-8")
+    checkCastWithParseError("2015-03-18T12:03:17-0:70")
+  }
+
   ///////////////////////////////////////////////////////////////////////////
   // Writing to Hive tables, which has special rules
   ///////////////////////////////////////////////////////////////////////////
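A note on the assertion in the new test: the DateTimeException is thrown inside an executor task, so it reaches the driver wrapped in a SparkException, which is why the test inspects exception.getCause rather than the exception itself. A minimal sketch of that unwrapping pattern, using a plain RuntimeException as a hypothetical stand-in for Spark's wrapper:

import java.time.DateTimeException

// Hypothetical stand-in for a job that fails inside a task; the framework
// rethrows the original error on the driver wrapped in another exception.
def runJob(): Unit =
  throw new RuntimeException("Job aborted", new DateTimeException("bad input"))

try {
  runJob()
} catch {
  case e: RuntimeException =>
    // The interesting failure is the cause, not the wrapper.
    assert(e.getCause.isInstanceOf[DateTimeException])
}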
