Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added TimeAdd #675

Merged
merged 2 commits into from
Sep 8, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions integration_tests/src/main/python/date_time_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ def test_timesub(data_gen):
lambda spark: unary_op_df(spark, TimestampGen(start=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1)
.selectExpr("a - (interval {} days {} seconds)".format(days, seconds)))

@pytest.mark.parametrize('data_gen', vals, ids=idfn)
def test_timeadd(data_gen):
days, seconds = data_gen
assert_gpu_and_cpu_are_equal_collect(
# We are starting at year 0005 to make sure we don't go before year 0001
# and beyond year 10000 while doing TimeAdd
lambda spark: unary_op_df(spark, TimestampGen(start=datetime(5, 1, 1, tzinfo=timezone.utc), end=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1)
.selectExpr("a + (interval {} days {} seconds)".format(days, seconds)))

@pytest.mark.parametrize('data_gen', date_gens, ids=idfn)
def test_datediff(data_gen):
assert_gpu_and_cpu_are_equal_collect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,30 +102,7 @@ class Spark310Shims extends Spark301Shims {
}

override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
tgravescs marked this conversation as resolved.
Show resolved Hide resolved
val exprs310: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
GpuOverrides.expr[TimeAdd](
"Subtracts interval from timestamp",
(a, conf, p, r) => new BinaryExprMeta[TimeAdd](a, conf, p, r) {
override def tagExprForGpu(): Unit = {
a.interval match {
case Literal(intvl: CalendarInterval, DataTypes.CalendarIntervalType) =>
if (intvl.months != 0) {
willNotWorkOnGpu("interval months isn't supported")
}
case _ =>
willNotWorkOnGpu("only literals are supported for intervals")
}
if (ZoneId.of(a.timeZoneId.get).normalized() != GpuOverrides.UTC_TIMEZONE_ID) {
willNotWorkOnGpu("Only UTC zone id is supported")
}
}

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuTimeSub(lhs, rhs)
}
)
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
exprs310 ++ super.exprs301
super.exprs301
}

override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand
import org.apache.spark.sql.rapids.execution.{GpuBroadcastMeta, GpuBroadcastNestedLoopJoinMeta, GpuCustomShuffleReaderExec, GpuShuffleMeta}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

/**
* Base class for all ReplacementRules
Expand Down Expand Up @@ -907,6 +907,28 @@ object GpuOverrides {
GpuDateDiff(lhs, rhs)
}
}),
expr[TimeAdd](
"Adds interval from timestamp",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interval "to" timestamp

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

(a, conf, p, r) => new BinaryExprMeta[TimeAdd](a, conf, p, r) {
override def tagExprForGpu(): Unit = {
a.interval match {
case Literal(intvl: CalendarInterval, DataTypes.CalendarIntervalType) =>
if (intvl.months != 0) {
willNotWorkOnGpu("interval months isn't supported")
}
case _ =>
willNotWorkOnGpu("only literals are supported for intervals")
}
if (ZoneId.of(a.timeZoneId.get).normalized() != GpuOverrides.UTC_TIMEZONE_ID) {
willNotWorkOnGpu("Only UTC zone id is supported")
}
}

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
GpuTimeAdd(lhs, rhs)
}
}
),
expr[ToUnixTimestamp](
"Returns the UNIX timestamp of the given time",
(a, conf, p, r) => new UnixTimeExprMeta[ToUnixTimestamp](a, conf, p, r){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,12 @@ case class GpuYear(child: Expression) extends GpuDateUnaryExpression {
GpuColumnVector.from(input.getBase.year())
}

case class GpuTimeSub(
abstract class GpuTimeMath(
start: Expression,
interval: Expression,
timeZoneId: Option[String] = None)
extends BinaryExpression with GpuExpression with TimeZoneAwareExpression with ExpectsInputTypes {
extends BinaryExpression with GpuExpression with TimeZoneAwareExpression with ExpectsInputTypes
with Serializable {

def this(start: Expression, interval: Expression) = this(start, interval, None)

Expand All @@ -136,10 +137,6 @@ case class GpuTimeSub(

override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
tgravescs marked this conversation as resolved.
Show resolved Hide resolved
copy(timeZoneId = Option(timeZoneId))
}

override def columnarEval(batch: ColumnarBatch): Any = {
var lhs: Any = null
var rhs: Any = null
Expand All @@ -156,7 +153,7 @@ case class GpuTimeSub(
if (usToSub != 0) {
withResource(Scalar.fromLong(usToSub)) { us_s =>
withResource(l.getBase.castTo(DType.INT64)) { us =>
withResource(us.sub(us_s)) {longResult =>
withResource(intervalMath(us_s, us)) { longResult =>
GpuColumnVector.from(longResult.castTo(DType.TIMESTAMP_MICROSECONDS))
}
}
Expand All @@ -177,6 +174,36 @@ case class GpuTimeSub(
}
}
}

def intervalMath(us_s: Scalar, us: ColumnVector): ColumnVector
}

case class GpuTimeAdd(start: Expression,
interval: Expression,
timeZoneId: Option[String] = None)
extends GpuTimeMath(start, interval, timeZoneId) {

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
}

override def intervalMath(us_s: Scalar, us: ColumnVector): ColumnVector = {
us.add(us_s)
tgravescs marked this conversation as resolved.
Show resolved Hide resolved
}
}

case class GpuTimeSub(start: Expression,
interval: Expression,
timeZoneId: Option[String] = None)
extends GpuTimeMath(start, interval, timeZoneId) {

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
}

def intervalMath(us_s: Scalar, us: ColumnVector): ColumnVector = {
us.sub(us_s)
}
}

case class GpuDateDiff(endDate: Expression, startDate: Expression)
Expand Down