Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable to_date (via gettimestamp and casting timestamp to date) for non-UTC time zones #10100

Merged
merged 15 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions integration_tests/src/main/python/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,11 @@ def test_cast_timestamp_to_string():
lambda spark: unary_op_df(spark, timestamp_gen)
.selectExpr("cast(a as string)"))

def test_cast_timestamp_to_date():
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, timestamp_gen)
.selectExpr("cast(a as date)"))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_cast_day_time_interval_to_string():
_assert_cast_to_string_equal(DayTimeIntervalGen(start_field='day', end_field='day', special_cases=[MIN_DAY_TIME_INTERVAL, MAX_DAY_TIME_INTERVAL, timedelta(seconds=0)]), {})
Expand Down Expand Up @@ -692,9 +697,9 @@ def test_cast_int_to_string_not_UTC():
lambda spark: unary_op_df(spark, int_gen, 100).selectExpr("a", "CAST(a AS STRING) as str"),
{"spark.sql.session.timeZone": "+08"})

not_utc_fallback_test_params = [(timestamp_gen, 'STRING'), (timestamp_gen, 'DATE'),
not_utc_fallback_test_params = [(timestamp_gen, 'STRING'),
# python does not like year 0, and with time zones the default start date can become year 0 :(
(DateGen(start=date(1, 1, 3)), 'TIMESTAMP'),
(DateGen(start=date(1, 1, 1)), 'TIMESTAMP'),
(SetValuesGen(StringType(), ['2023-03-20 10:38:50', '2023-03-20 10:39:02']), 'TIMESTAMP')]

@allow_non_gpu('ProjectExec')
Expand Down
2 changes: 1 addition & 1 deletion integration_tests/src/main/python/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from spark_session import is_before_spark_340, with_cpu_session
import sre_yield
import struct
from conftest import skip_unless_precommit_tests,get_datagen_seed, is_not_utc
from conftest import skip_unless_precommit_tests, get_datagen_seed, is_not_utc
import time
import os
from functools import lru_cache
Expand Down
18 changes: 9 additions & 9 deletions integration_tests/src/main/python/date_time_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def fun(spark):

assert_gpu_and_cpu_are_equal_collect(fun, conf=copy_and_update(parser_policy_dic, ansi_enabled_conf))


@pytest.mark.parametrize('ansi_enabled', [True, False], ids=['ANSI_ON', 'ANSI_OFF'])
@pytest.mark.parametrize('data_gen', date_n_time_gens, ids=idfn)
@tz_sensitive_test
Expand Down Expand Up @@ -427,22 +428,21 @@ def test_string_unix_timestamp_ansi_exception():
error_message="Exception",
conf=ansi_enabled_conf)

@pytest.mark.parametrize('data_gen', [StringGen('[0-9]{4}-0[1-9]-[0-2][1-8]')], ids=idfn)
@pytest.mark.parametrize('ansi_enabled', [True, False], ids=['ANSI_ON', 'ANSI_OFF'])
@allow_non_gpu(*non_utc_allow)
def test_gettimestamp(data_gen, ansi_enabled):
@pytest.mark.parametrize("ansi_enabled", [True, False], ids=['ANSI_ON', 'ANSI_OFF'])
@tz_sensitive_test
def test_to_date(ansi_enabled):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(f.to_date(f.col("a"), "yyyy-MM-dd")),
lambda spark : unary_op_df(spark, date_gen)
.select(f.to_date(f.col("a").cast('string'), "yyyy-MM-dd")),
{'spark.sql.ansi.enabled': ansi_enabled})


@pytest.mark.parametrize('data_gen', [StringGen('0[1-9][0-9]{4}')], ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_gettimestamp_format_MMyyyy(data_gen):
@tz_sensitive_test
def test_to_date_format_MMyyyy(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).select(f.to_date(f.col("a"), "MMyyyy")))

def test_gettimestamp_ansi_exception():
def test_to_date_ansi_exception():
assert_gpu_and_cpu_error(
lambda spark : invalid_date_string_df(spark).select(f.to_date(f.col("a"), "yyyy-MM-dd")).collect(),
error_message="Exception",
Expand Down
1 change: 0 additions & 1 deletion integration_tests/src/main/python/window_function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1682,7 +1682,6 @@ def test_window_first_last_nth_ignore_nulls(data_gen):


@ignore_order(local=True)
@allow_non_gpu(*non_utc_allow)
def test_to_date_with_window_functions():
"""
This test ensures that date expressions participating alongside window aggregations
Expand Down
24 changes: 21 additions & 3 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ import ai.rapids.cudf.{BinaryOp, CaptureGroups, ColumnVector, ColumnView, Decima
import ai.rapids.cudf
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.CastStrings
import com.nvidia.spark.rapids.jni.{CastStrings, GpuTimeZoneDB}
import com.nvidia.spark.rapids.shims.{AnsiUtil, GpuCastShims, GpuIntervalUtils, GpuTypeShims, SparkShimImpl, YearParseUtil}
import org.apache.commons.text.StringEscapeUtils

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, NullIntolerant, TimeZoneAwareExpression, UnaryExpression}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.GpuToTimestamp.replaceSpecialDates
import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
Expand Down Expand Up @@ -86,6 +87,13 @@ abstract class CastExprMetaBase[INPUT <: UnaryExpression with TimeZoneAwareExpre
val fromType: DataType = cast.child.dataType
val toType: DataType = cast.dataType

override def isTimeZoneSupported: Boolean = {
(fromType, toType) match {
case (TimestampType, DateType) => true // this is for to_date(...)
case _ => false
}
}

override def tagExprForGpu(): Unit = {
recursiveTagExprForGpuCheck()
}
Expand Down Expand Up @@ -209,13 +217,16 @@ object CastOptions {
* @param ansiMode Whether the cast should be ANSI compliant
* @param stringToDateAnsiMode Whether to cast String to Date using ANSI compliance
* @param castToJsonString Whether to use JSON format when casting to String
* @param ignoreNullFieldsInStructs Whether to omit null values when converting to JSON
* @param timeZoneId If cast is timezone aware, the timezone needed
*/
class CastOptions(
legacyCastComplexTypesToString: Boolean,
ansiMode: Boolean,
stringToDateAnsiMode: Boolean,
val castToJsonString: Boolean = false,
val ignoreNullFieldsInStructs: Boolean = true) extends Serializable {
val ignoreNullFieldsInStructs: Boolean = true,
val timeZoneId: Option[String] = Option.empty[String]) extends Serializable {

/**
* Retuns the left bracket to use when surrounding brackets when converting
Expand Down Expand Up @@ -614,6 +625,12 @@ object GpuCast {
case (_: IntegerType | ShortType | ByteType, ym: DataType)
if GpuTypeShims.isSupportedYearMonthType(ym) =>
GpuIntervalUtils.intToYearMonthInterval(input, ym)
case (TimestampType, DateType) if options.timeZoneId.isDefined =>
val zoneId = DateTimeUtils.getZoneId(options.timeZoneId.get)
withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(input.asInstanceOf[ColumnVector],
zoneId.normalized())) {
shifted => shifted.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType))
}
case _ =>
input.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType))
}
Expand Down Expand Up @@ -1807,7 +1824,8 @@ case class GpuCast(
import GpuCast._

private val options: CastOptions =
new CastOptions(legacyCastComplexTypesToString, ansiMode, stringToDateAnsiModeEnabled)
new CastOptions(legacyCastComplexTypesToString, ansiMode, stringToDateAnsiModeEnabled,
timeZoneId = timeZoneId)

// when ansi mode is enabled, some cast expressions can throw exceptions on invalid inputs
override def hasSideEffects: Boolean = super.hasSideEffects || {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1126,8 +1126,9 @@ abstract class BaseExprMeta[INPUT <: Expression](
if (!isTimeZoneSupported) return checkUTCTimezone(this)

// Level 3 check
if (!GpuTimeZoneDB.isSupportedTimeZone(getZoneId())) {
willNotWorkOnGpu(TimeZoneDB.timezoneNotSupportedStr(this.wrapped.getClass.toString))
val zoneId = getZoneId()
if (!GpuTimeZoneDB.isSupportedTimeZone(zoneId)) {
willNotWorkOnGpu(TimeZoneDB.timezoneNotSupportedStr(zoneId.toString))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ object TimeStamp {
.withPsNote(TypeEnum.STRING, "A limited number of formats are supported"),
TypeSig.STRING)),
(a, conf, p, r) => new UnixTimeExprMeta[GetTimestamp](a, conf, p, r) {
override def isTimeZoneSupported = true
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
GpuGetTimestamp(lhs, rhs, sparkFormat, strfFormat)
GpuGetTimestamp(lhs, rhs, sparkFormat, strfFormat, a.timeZoneId)
}
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
Expand Down