Added TimeAdd (NVIDIA#675)
* Added TimeAdd

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review comments

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

Co-authored-by: Raza Jafri <rjafri@nvidia.com>
razajafri authored Sep 8, 2020
1 parent f7edbe0 commit 7910229
Showing 5 changed files with 69 additions and 37 deletions.
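For context, TimeAdd is the Catalyst expression Spark plans for timestamp + interval arithmetic, the addition counterpart of the TimeSub expression the plugin already supported. A one-line sketch of the semantics this commit brings to the GPU (hypothetical spark-shell session, not part of the commit):

spark.sql("SELECT timestamp '2020-09-08 00:00:00' + interval 1 day 30 seconds").show()
// prints 2020-09-09 00:00:30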
1 change: 1 addition & 0 deletions docs/configs.md
@@ -198,6 +198,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Subtract"></a>spark.rapids.sql.expression.Subtract|`-`|Subtraction|true|None|
<a name="sql.expression.Tan"></a>spark.rapids.sql.expression.Tan|`tan`|Tangent|true|None|
<a name="sql.expression.Tanh"></a>spark.rapids.sql.expression.Tanh|`tanh`|Hyperbolic tangent|true|None|
<a name="sql.expression.TimeAdd"></a>spark.rapids.sql.expression.TimeAdd| |Adds interval to timestamp|true|None|
<a name="sql.expression.TimeSub"></a>spark.rapids.sql.expression.TimeSub| |Subtracts interval from timestamp|true|None|
<a name="sql.expression.ToDegrees"></a>spark.rapids.sql.expression.ToDegrees|`degrees`|Converts radians to degrees|true|None|
<a name="sql.expression.ToRadians"></a>spark.rapids.sql.expression.ToRadians|`radians`|Converts degrees to radians|true|None|
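The new spark.rapids.sql.expression.TimeAdd entry follows the same convention as the other expression configs in this table: it defaults to true and can be flipped at runtime to force TimeAdd back onto the CPU. A minimal sketch, assuming a SparkSession named spark with the RAPIDS plugin loaded (the events table and ts column are illustrative):

// Disable the GPU rule; ts + interval is then planned on the CPU.
spark.conf.set("spark.rapids.sql.expression.TimeAdd", "false")
spark.sql("SELECT ts + interval 2 days FROM events").explain()

// Re-enable it so the plugin may replace TimeAdd with its GPU implementation.
spark.conf.set("spark.rapids.sql.expression.TimeAdd", "true")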
12 changes: 9 additions & 3 deletions integration_tests/src/main/python/date_time_test.py
@@ -24,9 +24,6 @@
# We only support literal intervals for TimeSub
vals = [(-584, 1563), (1943, 1101), (2693, 2167), (2729, 0), (44, 1534), (2635, 3319),
(1885, -2828), (0, 2463), (932, 2286), (0, 0)]
@pytest.mark.xfail(
condition=not(is_before_spark_310()),
reason='https://issues.apache.org/jira/browse/SPARK-32640')
@pytest.mark.parametrize('data_gen', vals, ids=idfn)
def test_timesub(data_gen):
days, seconds = data_gen
@@ -35,6 +32,15 @@ def test_timesub(data_gen):
lambda spark: unary_op_df(spark, TimestampGen(start=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1)
.selectExpr("a - (interval {} days {} seconds)".format(days, seconds)))

@pytest.mark.parametrize('data_gen', vals, ids=idfn)
def test_timeadd(data_gen):
days, seconds = data_gen
assert_gpu_and_cpu_are_equal_collect(
# We are starting at year 0005 to make sure we don't go before year 0001
# and beyond year 10000 while doing TimeAdd
lambda spark: unary_op_df(spark, TimestampGen(start=datetime(5, 1, 1, tzinfo=timezone.utc), end=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1)
.selectExpr("a + (interval {} days {} seconds)".format(days, seconds)))

@pytest.mark.parametrize('data_gen', date_gens, ids=idfn)
def test_datediff(data_gen):
assert_gpu_and_cpu_are_equal_collect(
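test_timeadd mirrors test_timesub but exercises the addition path, and it bounds the generated timestamps to [0005-01-01, 0015-01-01) so the signed day/second pairs in vals (up to ±2729 days) cannot push a result outside the supported year range in either direction. A rough spark-shell equivalent of one generated case, using the (1943, 1101) entry of vals (a sketch, not part of the commit):

import java.sql.Timestamp
import spark.implicits._

val df = Seq(Timestamp.valueOf("0005-01-01 00:00:00")).toDF("a")
// Same shape as the selectExpr in test_timeadd.
df.selectExpr("a + (interval 1943 days 1101 seconds)").show(truncate = false)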
@@ -36,12 +36,11 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, GpuTimeSub, ShuffleManagerShimBase}
import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, ShuffleManagerShimBase}
import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExecBase
import org.apache.spark.sql.rapids.shims.spark310._
import org.apache.spark.sql.types._
import org.apache.spark.storage.{BlockId, BlockManagerId}
import org.apache.spark.unsafe.types.CalendarInterval

class Spark310Shims extends Spark301Shims {

@@ -102,30 +101,7 @@ class Spark310Shims extends Spark301Shims {
}

override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
val exprs310: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
GpuOverrides.expr[TimeAdd](
"Subtracts interval from timestamp",
(a, conf, p, r) => new BinaryExprMeta[TimeAdd](a, conf, p, r) {
override def tagExprForGpu(): Unit = {
a.interval match {
case Literal(intvl: CalendarInterval, DataTypes.CalendarIntervalType) =>
if (intvl.months != 0) {
willNotWorkOnGpu("interval months isn't supported")
}
case _ =>
willNotWorkOnGpu("only literals are supported for intervals")
}
if (ZoneId.of(a.timeZoneId.get).normalized() != GpuOverrides.UTC_TIMEZONE_ID) {
willNotWorkOnGpu("Only UTC zone id is supported")
}
}

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuTimeSub(lhs, rhs)
}
)
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
exprs310 ++ super.exprs301
super.exprs301
}

override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = {
@@ -47,7 +47,7 @@ import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand
import org.apache.spark.sql.rapids.execution.{GpuBroadcastMeta, GpuBroadcastNestedLoopJoinMeta, GpuCustomShuffleReaderExec, GpuShuffleMeta}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

/**
* Base class for all ReplacementRules
@@ -907,6 +907,28 @@ object GpuOverrides {
GpuDateDiff(lhs, rhs)
}
}),
expr[TimeAdd](
"Adds interval to timestamp",
(a, conf, p, r) => new BinaryExprMeta[TimeAdd](a, conf, p, r) {
override def tagExprForGpu(): Unit = {
a.interval match {
case Literal(intvl: CalendarInterval, DataTypes.CalendarIntervalType) =>
if (intvl.months != 0) {
willNotWorkOnGpu("interval months isn't supported")
}
case _ =>
willNotWorkOnGpu("only literals are supported for intervals")
}
if (ZoneId.of(a.timeZoneId.get).normalized() != GpuOverrides.UTC_TIMEZONE_ID) {
willNotWorkOnGpu("Only UTC zone id is supported")
}
}

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
GpuTimeAdd(lhs, rhs)
}
}
),
expr[ToUnixTimestamp](
"Returns the UNIX timestamp of the given time",
(a, conf, p, r) => new UnixTimeExprMeta[ToUnixTimestamp](a, conf, p, r){
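The expr[TimeAdd] rule above is essentially the block deleted from the Spark310Shims file earlier in this diff, relocated into the common overrides so every supported Spark version picks it up, with the description corrected from "Subtracts interval from timestamp" to "Adds interval to timestamp" and the conversion target changed from GpuTimeSub to GpuTimeAdd. The tagging logic limits which plans stay on the GPU: the interval must be a literal CalendarInterval with no month component, and the session time zone must normalize to UTC. A hedged sketch of what that means for queries (table t and columns ts, d, s are illustrative):

// Stays on the GPU: literal day/second interval, UTC session time zone.
spark.sql("SELECT ts + interval 3 days 10 seconds FROM t")

// Falls back to the CPU: the interval carries a month component.
spark.sql("SELECT ts + interval 1 month FROM t")

// Falls back to the CPU: the interval is computed per row, not a literal.
spark.sql("SELECT ts + make_interval(0, 0, 0, d, 0, 0, s) FROM t")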
@@ -117,11 +117,12 @@ case class GpuYear(child: Expression) extends GpuDateUnaryExpression {
GpuColumnVector.from(input.getBase.year())
}

case class GpuTimeSub(
abstract class GpuTimeMath(
start: Expression,
interval: Expression,
timeZoneId: Option[String] = None)
extends BinaryExpression with GpuExpression with TimeZoneAwareExpression with ExpectsInputTypes {
extends BinaryExpression with GpuExpression with TimeZoneAwareExpression with ExpectsInputTypes
with Serializable {

def this(start: Expression, interval: Expression) = this(start, interval, None)

@@ -136,10 +137,6 @@ case class GpuTimeSub(

override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
}

override def columnarEval(batch: ColumnarBatch): Any = {
var lhs: Any = null
var rhs: Any = null
@@ -156,7 +153,7 @@
if (usToSub != 0) {
withResource(Scalar.fromLong(usToSub)) { us_s =>
withResource(l.getBase.castTo(DType.INT64)) { us =>
withResource(us.sub(us_s)) {longResult =>
withResource(intervalMath(us_s, us)) { longResult =>
GpuColumnVector.from(longResult.castTo(DType.TIMESTAMP_MICROSECONDS))
}
}
@@ -177,6 +174,36 @@ case class GpuTimeSub(
}
}
}

def intervalMath(us_s: Scalar, us: ColumnVector): ColumnVector
}

case class GpuTimeAdd(start: Expression,
interval: Expression,
timeZoneId: Option[String] = None)
extends GpuTimeMath(start, interval, timeZoneId) {

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
}

override def intervalMath(us_s: Scalar, us: ColumnVector): ColumnVector = {
us.add(us_s)
}
}

case class GpuTimeSub(start: Expression,
interval: Expression,
timeZoneId: Option[String] = None)
extends GpuTimeMath(start, interval, timeZoneId) {

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
}

def intervalMath(us_s: Scalar, us: ColumnVector): ColumnVector = {
us.sub(us_s)
}
}

case class GpuDateDiff(endDate: Expression, startDate: Expression)
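The refactor above is a template method: the old GpuTimeSub body becomes the shared GpuTimeMath base, which keeps the null handling, the literal-interval unpacking, the cast to INT64 microseconds, and the cast back to TIMESTAMP_MICROSECONDS, while each concrete class supplies only the one-line intervalMath hook (us.add versus us.sub on the cuDF column). A stripped-down sketch of that shape, independent of cuDF so it runs anywhere (all names are illustrative, not part of the commit):

// The base class owns the pipeline; subclasses provide one hook.
abstract class TimeMathSketch {
  // Stands in for columnarEval, which also handles nulls, casts, and cleanup.
  final def eval(timestampUs: Long, intervalUs: Long): Long =
    intervalMath(timestampUs, intervalUs)

  protected def intervalMath(us: Long, interval: Long): Long
}

object AddSketch extends TimeMathSketch {   // GpuTimeAdd's role
  protected def intervalMath(us: Long, interval: Long): Long = us + interval
}

object SubSketch extends TimeMathSketch {   // GpuTimeSub's role
  protected def intervalMath(us: Long, interval: Long): Long = us - interval
}

Note that withTimeZone still has to be overridden in each case class rather than once in GpuTimeMath, because the copy method it relies on is generated per case class.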
