Skip to content

Commit

Permalink
Support multiplication on ANSI interval types (#5105)
Browse files Browse the repository at this point in the history
Signed-off-by: Chong Gao <res_life@163.com>
  • Loading branch information
res-life authored Apr 14, 2022
1 parent a8d919b commit 6c27068
Show file tree
Hide file tree
Showing 6 changed files with 765 additions and 6 deletions.
8 changes: 8 additions & 0 deletions integration_tests/src/main/python/arithmetic_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,3 +946,11 @@ def test_subtraction_overflow_with_ansi_enabled_day_time_interval(ansi_enabled):
def test_unary_positive_day_time_interval():
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, DayTimeIntervalGen()).selectExpr('+a'))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('data_gen', _no_overflow_multiply_gens + [DoubleGen(min_exp=-3, max_exp=5, special_cases=[0.0])], ids=idfn)
def test_day_time_interval_multiply_number(data_gen):
gen_list = [('_c1', DayTimeIntervalGen(min_value=timedelta(seconds=-20 * 86400), max_value=timedelta(seconds=20 * 86400))),
('_c2', data_gen)]
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, gen_list).selectExpr("_c1 * _c2"))
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,9 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
ExprChecks.projectAndAst(
TypeSig.astTypes,
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.CALENDAR
+ TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT + TypeSig.DAYTIME)
.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT),
+ TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT + TypeSig.ansiIntervals)
.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT),
TypeSig.all),
(lit, conf, p, r) => new LiteralExprMeta(lit, conf, p, r)),
GpuOverrides.expr[TimeAdd](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ object GpuTypeShims {
*/
def supportToScalarForType(t: DataType): Boolean = {
t match {
case _: YearMonthIntervalType => true
case _: DayTimeIntervalType => true
case _ => false
}
Expand All @@ -143,8 +144,13 @@ object GpuTypeShims {
/**
* Convert the given value to Scalar
*/
def toScalarForType(t: DataType, v: Any) = {
def toScalarForType(t: DataType, v: Any): Scalar = {
t match {
case _: YearMonthIntervalType => v match {
case i: Int => Scalar.fromInt(i)
case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
s" for IntType, expecting int")
}
case _: DayTimeIntervalType => v match {
case l: Long => Scalar.fromLong(l)
case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.shims.GpuTimeAdd
import org.apache.spark.sql.rapids.shims.{GpuMultiplyDTInterval, GpuMultiplyYMInterval, GpuTimeAdd}
import org.apache.spark.sql.types.{CalendarIntervalType, DayTimeIntervalType, DecimalType, StructType}
import org.apache.spark.unsafe.types.CalendarInterval

Expand Down Expand Up @@ -181,7 +181,29 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims {

// ANSI support for ABS was added in 3.2.0 SPARK-33275
override def convertToGpu(child: Expression): GpuExpression = GpuAbs(child, ansiEnabled)
})
}),
GpuOverrides.expr[MultiplyYMInterval](
"Year-month interval * number",
ExprChecks.binaryProject(
TypeSig.YEARMONTH,
TypeSig.YEARMONTH,
("lhs", TypeSig.YEARMONTH, TypeSig.YEARMONTH),
("rhs", TypeSig.gpuNumeric - TypeSig.DECIMAL_128, TypeSig.gpuNumeric)),
(a, conf, p, r) => new BinaryExprMeta[MultiplyYMInterval](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuMultiplyYMInterval(lhs, rhs)
}),
GpuOverrides.expr[MultiplyDTInterval](
"Day-time interval * number",
ExprChecks.binaryProject(
TypeSig.DAYTIME,
TypeSig.DAYTIME,
("lhs", TypeSig.DAYTIME, TypeSig.DAYTIME),
("rhs", TypeSig.gpuNumeric - TypeSig.DECIMAL_128, TypeSig.gpuNumeric)),
(a, conf, p, r) => new BinaryExprMeta[MultiplyDTInterval](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuMultiplyDTInterval(lhs, rhs)
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
super.getExprs ++ map
}
Expand Down
Loading

0 comments on commit 6c27068

Please sign in to comment.