diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
index 6cae021bafa..6a3f8624a8d 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -17,6 +17,7 @@
 package com.nvidia.spark.rapids
 
 import java.text.SimpleDateFormat
+import java.time.DateTimeException
 
 import ai.rapids.cudf.{ColumnVector, DType, Scalar}
 
@@ -784,8 +785,27 @@ case class GpuCast(
           convertTimestampOrNull(sanitizedInput, TIMESTAMP_REGEX_YYYY, "%Y"))))
 
       // handle special dates like "epoch", "now", etc.
-      specialDates.foldLeft(converted)((prev, specialDate) =>
+      val finalResult = specialDates.foldLeft(converted)((prev, specialDate) =>
         specialTimestampOr(sanitizedInput, specialDate._1, specialDate._2, prev))
+
+      // When ANSI mode is enabled, we need to throw an exception if any values could not be
+      // converted
+      if (ansiMode) {
+        closeOnExcept(finalResult) { finalResult =>
+          withResource(input.isNotNull) { wasNotNull =>
+            withResource(finalResult.isNull) { isNull =>
+              withResource(wasNotNull.and(isNull)) { notConverted =>
+                if (notConverted.any().getBoolean) {
+                  throw new DateTimeException(
+                    "One or more values could not be converted to TimestampType")
+                }
+              }
+            }
+          }
+        }
+      }
+
+      finalResult
     }
   }
 
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
index 81dfee6fa1b..4f0206e657a 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
@@ -17,10 +17,11 @@
 package com.nvidia.spark.rapids
 
 import java.sql.Timestamp
+import java.time.DateTimeException
 
 import scala.util.Random
 
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, CastBase}
 import org.apache.spark.sql.execution.ProjectExec
@@ -38,6 +39,7 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
     .set(RapidsConf.ENABLE_CAST_FLOAT_TO_STRING.key, "true")
     .set(RapidsConf.ENABLE_CAST_STRING_TO_INTEGER.key, "true")
     .set(RapidsConf.ENABLE_CAST_STRING_TO_FLOAT.key, "true")
+    .set(RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP.key, "true")
 
   def generateOutOfRangeTimestampsDF(
       lowerValue: Long,
@@ -420,6 +422,37 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
     frame => testCastTo(DataTypes.TimestampType)(frame)
   }
 
+  test("ANSI mode: cast string to timestamp with parse error") {
+    // Copied from Spark CastSuite
+
+    def checkCastWithParseError(str: String): Unit = {
+      val exception = intercept[SparkException] {
+        withGpuSparkSession(spark => {
+          import spark.implicits._
+
+          val df = Seq(str).toDF("c0")
+            .repartition(2)
+            .withColumn("c1", col("c0").cast(DataTypes.TimestampType))
+
+          val result = df.collect()
+          result.foreach(println)
+
+        }, sparkConf)
+      }
+      assert(exception.getCause.isInstanceOf[DateTimeException])
+    }
+
+    checkCastWithParseError("123")
+    checkCastWithParseError("2015-03-18 123142")
+    checkCastWithParseError("2015-03-18T123123")
+    checkCastWithParseError("2015-03-18X")
+    checkCastWithParseError("2015/03/18")
+    checkCastWithParseError("2015.03.18")
+    checkCastWithParseError("20150318")
+    checkCastWithParseError("2015-031-8")
+    checkCastWithParseError("2015-03-18T12:03:17-0:70")
+  }
+
   ///////////////////////////////////////////////////////////////////////////
   // Writing to Hive tables, which has special rules
   ///////////////////////////////////////////////////////////////////////////
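
For context, a minimal sketch (not part of the patch) of how the new ANSI behavior surfaces to an end user. It assumes the RAPIDS Accelerator is on the classpath and registered as the Spark plugin; the `spark.rapids.sql.castStringToTimestamp.enabled` key is an assumption inferred from the `RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP` setting used in the test setup above.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DataTypes

object AnsiCastStringToTimestampExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("ansi-cast-string-to-timestamp")
      // Enable Spark's ANSI mode so invalid casts raise errors instead of producing nulls.
      .config("spark.sql.ansi.enabled", "true")
      // Assumed RAPIDS config key allowing string-to-timestamp casts on the GPU,
      // inferred from RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP in the test above.
      .config("spark.rapids.sql.castStringToTimestamp.enabled", "true")
      .getOrCreate()
    import spark.implicits._

    // "2015/03/18" is one of the inputs the new test expects to fail parsing.
    val df = Seq("2015/03/18").toDF("c0")
      .withColumn("c1", col("c0").cast(DataTypes.TimestampType))

    // With this patch and ANSI mode enabled, collect() fails with a SparkException
    // whose cause is a java.time.DateTimeException rather than silently yielding null.
    df.collect()
  }
}
```

Without ANSI mode, the same query still returns a null timestamp for the unparsable value, matching the pre-patch behavior.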