Skip to content

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
Signed-off-by: Niranjan Artal <nartal@nvidia.com>
  • Loading branch information
nartal1 committed Mar 10, 2021
1 parent e99680b commit e969ad9
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package com.nvidia.spark.rapids.shims.spark300

import java.nio.ByteBuffer
import java.time.ZoneId

import scala.collection.mutable.ListBuffer

Expand Down Expand Up @@ -246,22 +245,21 @@ class Spark300Shims extends SparkShims {
("start", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP),
("interval", TypeSig.lit(TypeEnum.CALENDAR)
.withPsNote(TypeEnum.CALENDAR, "months not supported"), TypeSig.CALENDAR)),
(timeSub, conf, p, r) =>
new BinaryExprMeta[TimeSub](timeSub, conf, p, r) with TimeZoneCheck {
override def tagExprForGpu(): Unit = {
timeSub.interval match {
case Literal(intvl: CalendarInterval, DataTypes.CalendarIntervalType) =>
if (intvl.months != 0) {
willNotWorkOnGpu("interval months isn't supported")
}
case _ =>
}
checkTimeZoneId(timeSub.timeZoneId, this)
(timeSub, conf, p, r) => new BinaryExprMeta[TimeSub](timeSub, conf, p, r) {
override def tagExprForGpu(): Unit = {
timeSub.interval match {
case Literal(intvl: CalendarInterval, DataTypes.CalendarIntervalType) =>
if (intvl.months != 0) {
willNotWorkOnGpu("interval months isn't supported")
}
case _ =>
}
checkTimeZoneId(timeSub.timeZoneId)
}

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuTimeSub(lhs, rhs)
}),
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuTimeSub(lhs, rhs)
}),
GpuOverrides.expr[First](
"first aggregate operator",
ExprChecks.aggNotWindow(TypeSig.commonCudfTypes + TypeSig.NULL, TypeSig.all,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,19 +408,6 @@ trait GpuOverridesListener {
costOptimizations: Seq[Optimization])
}

/**
* trait to check if the zone id is in UTC
*/
trait TimeZoneCheck {
def checkTimeZoneId(timeZoneId: Option[String], meta: RapidsMeta[_, _, _]): Unit = {
timeZoneId.foreach { zoneId =>
if (ZoneId.of(zoneId).normalized() != GpuOverrides.UTC_TIMEZONE_ID) {
meta.willNotWorkOnGpu(s"Only UTC zone id is supported. Actual zone id: $zoneId")
}
}
}
}

object GpuOverrides {
val FLOAT_DIFFERS_GROUP_INCOMPAT =
"when enabling these, there may be extra groups produced for floating point grouping " +
Expand Down Expand Up @@ -1327,15 +1314,15 @@ object GpuOverrides {
("interval", TypeSig.lit(TypeEnum.CALENDAR)
.withPsNote(TypeEnum.CALENDAR, "month intervals are not supported"),
TypeSig.CALENDAR)),
(timeAdd, conf, p, r) => new BinaryExprMeta[TimeAdd](timeAdd, conf, p, r) with TimeZoneCheck {
(timeAdd, conf, p, r) => new BinaryExprMeta[TimeAdd](timeAdd, conf, p, r) {
override def tagExprForGpu(): Unit = {
GpuOverrides.extractLit(timeAdd.interval).foreach { lit =>
val intvl = lit.value.asInstanceOf[CalendarInterval]
if (intvl.months != 0) {
willNotWorkOnGpu("interval months isn't supported")
}
}
checkTimeZoneId(timeAdd.timeZoneId, this)
checkTimeZoneId(timeAdd.timeZoneId)
}

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
Expand All @@ -1349,15 +1336,15 @@ object GpuOverrides {
.withPsNote(TypeEnum.CALENDAR, "month intervals are not supported"),
TypeSig.CALENDAR)),
(dateAddInterval, conf, p, r) =>
new BinaryExprMeta[DateAddInterval](dateAddInterval, conf, p, r) with TimeZoneCheck {
new BinaryExprMeta[DateAddInterval](dateAddInterval, conf, p, r) {
override def tagExprForGpu(): Unit = {
GpuOverrides.extractLit(dateAddInterval.interval).foreach { lit =>
val intvl = lit.value.asInstanceOf[CalendarInterval]
if (intvl.months != 0) {
willNotWorkOnGpu("interval months isn't supported")
}
}
checkTimeZoneId(dateAddInterval.timeZoneId, this)
checkTimeZoneId(dateAddInterval.timeZoneId)
}

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
Expand Down Expand Up @@ -1405,9 +1392,9 @@ object GpuOverrides {
"Returns the hour component of the string/timestamp",
ExprChecks.unaryProjectNotLambda(TypeSig.INT, TypeSig.INT,
TypeSig.TIMESTAMP, TypeSig.TIMESTAMP),
(hour, conf, p, r) => new UnaryExprMeta[Hour](hour, conf, p, r) with TimeZoneCheck {
(hour, conf, p, r) => new UnaryExprMeta[Hour](hour, conf, p, r) {
override def tagExprForGpu(): Unit = {
checkTimeZoneId(hour.timeZoneId, this)
checkTimeZoneId(hour.timeZoneId)
}

override def convertToGpu(expr: Expression): GpuExpression = GpuHour(expr)
Expand All @@ -1416,9 +1403,9 @@ object GpuOverrides {
"Returns the minute component of the string/timestamp",
ExprChecks.unaryProjectNotLambda(TypeSig.INT, TypeSig.INT,
TypeSig.TIMESTAMP, TypeSig.TIMESTAMP),
(minute, conf, p, r) => new UnaryExprMeta[Minute](minute, conf, p, r) with TimeZoneCheck {
(minute, conf, p, r) => new UnaryExprMeta[Minute](minute, conf, p, r) {
override def tagExprForGpu(): Unit = {
checkTimeZoneId(minute.timeZoneId, this)
checkTimeZoneId(minute.timeZoneId)
}

override def convertToGpu(expr: Expression): GpuExpression =
Expand All @@ -1428,9 +1415,9 @@ object GpuOverrides {
"Returns the second component of the string/timestamp",
ExprChecks.unaryProjectNotLambda(TypeSig.INT, TypeSig.INT,
TypeSig.TIMESTAMP, TypeSig.TIMESTAMP),
(second, conf, p, r) => new UnaryExprMeta[Second](second, conf, p, r) with TimeZoneCheck {
(second, conf, p, r) => new UnaryExprMeta[Second](second, conf, p, r) {
override def tagExprForGpu(): Unit = {
checkTimeZoneId(second.timeZoneId, this)
checkTimeZoneId(second.timeZoneId)
}

override def convertToGpu(expr: Expression): GpuExpression =
Expand Down
10 changes: 10 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.nvidia.spark.rapids

import java.time.ZoneId

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ComplexTypeMergingExpression, Expression, LambdaFunction, String2TrimExpression, TernaryExpression, UnaryExpression, WindowExpression, WindowFunction}
Expand Down Expand Up @@ -325,6 +327,14 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
}
}

protected def checkTimeZoneId(timeZoneId: Option[String]): Unit = {
timeZoneId.foreach { zoneId =>
if (ZoneId.of(zoneId).normalized() != GpuOverrides.UTC_TIMEZONE_ID) {
willNotWorkOnGpu(s"Only UTC zone id is supported. Actual zone id: $zoneId")
}
}
}

/**
* Create a string representation of this in append.
* @param strBuilder where to place the string representation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@

package org.apache.spark.sql.rapids

import java.time.ZoneId
import java.util.concurrent.TimeUnit

import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta, TimeZoneCheck}
import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta}
import com.nvidia.spark.rapids.DateUtils.TimestampFormatConversionException
import com.nvidia.spark.rapids.GpuOverrides.{extractStringLit, getTimeParserPolicy}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
Expand Down Expand Up @@ -342,11 +341,11 @@ abstract class UnixTimeExprMeta[A <: BinaryExpression with TimeZoneAwareExpressi
(expr: A, conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends BinaryExprMeta[A](expr, conf, parent, rule) with TimeZoneCheck {
extends BinaryExprMeta[A](expr, conf, parent, rule) {
var sparkFormat: String = _
var strfFormat: String = _
override def tagExprForGpu(): Unit = {
checkTimeZoneId(expr.timeZoneId, this)
checkTimeZoneId(expr.timeZoneId)

// Date and Timestamp work too
if (expr.right.dataType == StringType) {
Expand Down

0 comments on commit e969ad9

Please sign in to comment.