Skip to content

Commit

Permalink
Fix rounding error when casting timestamp to string for timestamps be…
Browse files Browse the repository at this point in the history
…fore 1970 (NVIDIA#893)

* Fix error casting pre-1970s timestamps to string

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Add reference to cuDF issue

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* release another vector

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* keep timestamp as microseconds

* rename variables

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* rename variable

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* revert accidental change

* support timestamps with microseconds

Signed-off-by: Andy Grove <andygrove@nvidia.com>
  • Loading branch information
andygrove authored Oct 2, 2020
1 parent 34c2d85 commit 34c01a1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
28 changes: 24 additions & 4 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -491,10 +491,30 @@ case class GpuCast(
}

private def castTimestampToString(input: GpuColumnVector) = {
GpuColumnVector.from(
withResource(input.getBase.asStrings("%Y-%m-%d %H:%M:%S.%3f")) { cv =>
cv.stringReplaceWithBackrefs(GpuCast.TIMESTAMP_TRUNCATE_REGEX,"\\1\\2\\3")
})
// https://github.com/rapidsai/cudf/issues/5166
// The time is off by 1 second if the result is < 0
val adjustedTimestamp = withResource(input.getBase
.castTo(DType.TIMESTAMP_MICROSECONDS)) { micros =>
withResource(micros.castTo(DType.INT64)) { micros =>
withResource(Scalar.fromLong(1000000)) { oneSecond =>
withResource(micros.sub(oneSecond)) { subOne =>
withResource(Scalar.fromLong(0)) { zero =>
withResource(micros.lessThan(zero)) { neg =>
neg.ifElse(subOne, micros)
}
}
}
}
}
}
withResource(adjustedTimestamp) { adjustedTimestamp =>
withResource(adjustedTimestamp.castTo(DType.TIMESTAMP_MICROSECONDS)) { micros =>
withResource(micros.asStrings("%Y-%m-%d %H:%M:%S.%6f")) { cv =>
GpuColumnVector.from(cv.stringReplaceWithBackrefs(
GpuCast.TIMESTAMP_TRUNCATE_REGEX, "\\1\\2\\3"))
}
}
}
}

private def castFloatingTypeToString(input: GpuColumnVector): GpuColumnVector = {
Expand Down
13 changes: 8 additions & 5 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package com.nvidia.spark.rapids

import java.sql.Timestamp
import java.text.SimpleDateFormat
import java.time.LocalDateTime
import java.util.TimeZone

import org.apache.spark.SparkConf
Expand Down Expand Up @@ -690,9 +690,13 @@ object CastOpSuite {

def validTimestamps(session: SparkSession): DataFrame = {
import session.sqlContext.implicits._
val df = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS")
val timestampStrings = Seq(
"2020-12-31T11:59:59.999",
"1920-12-31T11:59:59.999",
"1969-12-31T23:59:59.999",
"1969-12-31T23:59:59.999999",
"1970-01-01T00:00:00.000",
"1970-01-01T00:00:00.999",
"1970-01-01T00:00:00.999111",
"2020-12-31T11:59:59.990",
"2020-12-31T11:59:59.900",
"2020-12-31T11:59:59.000",
Expand All @@ -702,8 +706,7 @@ object CastOpSuite {
"2020-12-31T11:00:00.000"
)
val timestamps = timestampStrings
.map(s => df.parse(s))
.map(d => new Timestamp(d.getTime))
.map(s => Timestamp.valueOf(LocalDateTime.parse(s)))

timestamps.toDF("c0")
}
Expand Down

0 comments on commit 34c01a1

Please sign in to comment.