diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
index 9c0b2a5de51..d8c87370b50 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -491,10 +491,30 @@ case class GpuCast(
   }
 
   private def castTimestampToString(input: GpuColumnVector) = {
-    GpuColumnVector.from(
-      withResource(input.getBase.asStrings("%Y-%m-%d %H:%M:%S.%3f")) { cv =>
-        cv.stringReplaceWithBackrefs(GpuCast.TIMESTAMP_TRUNCATE_REGEX,"\\1\\2\\3")
-      })
+    // https://github.com/rapidsai/cudf/issues/5166
+    // The time is off by 1 second if the result is < 0
+    val adjustedTimestamp = withResource(input.getBase
+        .castTo(DType.TIMESTAMP_MICROSECONDS)) { micros =>
+      withResource(micros.castTo(DType.INT64)) { micros =>
+        withResource(Scalar.fromLong(1000000)) { oneSecond =>
+          withResource(micros.sub(oneSecond)) { subOne =>
+            withResource(Scalar.fromLong(0)) { zero =>
+              withResource(micros.lessThan(zero)) { neg =>
+                neg.ifElse(subOne, micros)
+              }
+            }
+          }
+        }
+      }
+    }
+    withResource(adjustedTimestamp) { adjustedTimestamp =>
+      withResource(adjustedTimestamp.castTo(DType.TIMESTAMP_MICROSECONDS)) { micros =>
+        withResource(micros.asStrings("%Y-%m-%d %H:%M:%S.%6f")) { cv =>
+          GpuColumnVector.from(cv.stringReplaceWithBackrefs(
+            GpuCast.TIMESTAMP_TRUNCATE_REGEX, "\\1\\2\\3"))
+        }
+      }
+    }
   }
 
   private def castFloatingTypeToString(input: GpuColumnVector): GpuColumnVector = {
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
index b7bd15be42a..7b51e51c846 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
@@ -17,7 +17,7 @@
 package com.nvidia.spark.rapids
 
 import java.sql.Timestamp
-import java.text.SimpleDateFormat
+import java.time.LocalDateTime
 import java.util.TimeZone
 
 import org.apache.spark.SparkConf
@@ -690,9 +690,13 @@ object CastOpSuite {
 
   def validTimestamps(session: SparkSession): DataFrame = {
     import session.sqlContext.implicits._
-    val df = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS")
     val timestampStrings = Seq(
-      "2020-12-31T11:59:59.999",
+      "1920-12-31T11:59:59.999",
+      "1969-12-31T23:59:59.999",
+      "1969-12-31T23:59:59.999999",
+      "1970-01-01T00:00:00.000",
+      "1970-01-01T00:00:00.999",
+      "1970-01-01T00:00:00.999111",
       "2020-12-31T11:59:59.990",
       "2020-12-31T11:59:59.900",
       "2020-12-31T11:59:59.000",
@@ -702,8 +706,7 @@
       "2020-12-31T11:00:00.000"
     )
     val timestamps = timestampStrings
-      .map(s => df.parse(s))
-      .map(d => new Timestamp(d.getTime))
+      .map(s => Timestamp.valueOf(LocalDateTime.parse(s)))
     timestamps.toDF("c0")
   }
 
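
Note on the GpuCast hunk: the in-diff comment cites https://github.com/rapidsai/cudf/issues/5166, where formatted timestamps come out one second late when the underlying value is negative (pre-epoch). One plausible way such an off-by-one arises is splitting microseconds-since-epoch into whole seconds with truncating rather than flooring integer division. The sketch below is plain JVM Scala with illustrative names, not the plugin's GPU code; it only demonstrates the arithmetic and why subtracting one second from negative inputs (as the diff does column-wise via ifElse) lands the truncating conversion on the right second.

object NegativeTimestampSketch {
  val MicrosPerSecond = 1000000L

  // Truncating split: JVM integer division rounds toward zero, the kind of
  // behavior that makes negative timestamps render one second late.
  def truncatedSeconds(micros: Long): Long = micros / MicrosPerSecond

  // Correct split: floor division rounds toward negative infinity.
  def flooredSeconds(micros: Long): Long = Math.floorDiv(micros, MicrosPerSecond)

  def main(args: Array[String]): Unit = {
    // One microsecond before the epoch: 1969-12-31T23:59:59.999999
    val micros = -1L
    println(truncatedSeconds(micros)) // 0  -> renders as 00:00:00, one second late
    println(flooredSeconds(micros))   // -1 -> renders as 23:59:59, correct

    // The workaround mirrored from the diff: shift negative inputs back one
    // second so the truncating conversion yields the floored result.
    val adjusted = if (micros < 0) micros - MicrosPerSecond else micros
    println(truncatedSeconds(adjusted)) // -1, matching flooredSeconds(micros)
  }
}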
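
Note on the CastOpSuite hunks: the new test inputs carry microsecond fractions such as "1969-12-31T23:59:59.999999", which SimpleDateFormat cannot handle because its "SSS" field means milliseconds; it reads ".999999" as 999,999 ms and silently rolls the timestamp forward by more than 16 minutes. java.time parses the full ISO-8601 fraction, and Timestamp.valueOf preserves it. A self-contained sketch of the difference (the object name is hypothetical):

import java.sql.Timestamp
import java.text.SimpleDateFormat
import java.time.LocalDateTime

object ParsePrecisionSketch {
  def main(args: Array[String]): Unit = {
    val s = "1969-12-31T23:59:59.999999"

    // Old approach: "SSS" is a milliseconds field, so 999999 is treated as
    // ~1000 seconds and normalized into the minutes and hours.
    val legacy = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS")
    println(new Timestamp(legacy.parse(s).getTime)) // 1970-01-01 00:16:38.999

    // New approach: the ISO-8601 fraction is parsed at full precision and
    // Timestamp.valueOf(LocalDateTime) keeps the sub-millisecond digits.
    println(Timestamp.valueOf(LocalDateTime.parse(s))) // 1969-12-31 23:59:59.999999
  }
}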