diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py
index 2c84b689571..d529bc19c75 100644
--- a/integration_tests/src/main/python/arithmetic_ops_test.py
+++ b/integration_tests/src/main/python/arithmetic_ops_test.py
@@ -946,3 +946,11 @@ def test_subtraction_overflow_with_ansi_enabled_day_time_interval(ansi_enabled):
 def test_unary_positive_day_time_interval():
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: unary_op_df(spark, DayTimeIntervalGen()).selectExpr('+a'))
+
+@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
+@pytest.mark.parametrize('data_gen', _no_overflow_multiply_gens + [DoubleGen(min_exp=-3, max_exp=5, special_cases=[0.0])], ids=idfn)
+def test_day_time_interval_multiply_number(data_gen):
+    gen_list = [('_c1', DayTimeIntervalGen(min_value=timedelta(seconds=-20 * 86400), max_value=timedelta(seconds=20 * 86400))),
+                ('_c2', data_gen)]
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: gen_df(spark, gen_list).selectExpr("_c1 * _c2"))
diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
index 6c364951532..b09eb7ba41a 100644
--- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
+++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
@@ -250,9 +250,9 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
       ExprChecks.projectAndAst(
         TypeSig.astTypes,
         (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.CALENDAR
-            + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT + TypeSig.DAYTIME)
-            .nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
-                TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT),
+            + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT + TypeSig.ansiIntervals)
+            .nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+                TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT),
         TypeSig.all),
       (lit, conf, p, r) => new LiteralExprMeta(lit, conf, p, r)),
     GpuOverrides.expr[TimeAdd](
diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala
index 45432da4428..4fb6f8c3fe9 100644
--- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala
+++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala
@@ -135,6 +135,7 @@ object GpuTypeShims {
    */
   def supportToScalarForType(t: DataType): Boolean = {
     t match {
+      case _: YearMonthIntervalType => true
       case _: DayTimeIntervalType => true
       case _ => false
     }
   }
@@ -143,8 +144,13 @@
   /**
    * Convert the given value to Scalar
    */
-  def toScalarForType(t: DataType, v: Any) = {
+  def toScalarForType(t: DataType, v: Any): Scalar = {
     t match {
+      case _: YearMonthIntervalType => v match {
+        case i: Int => Scalar.fromInt(i)
+        case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
+            s" for YearMonthIntervalType, expecting int")
+      }
       case _: DayTimeIntervalType => v match {
         case l: Long => Scalar.fromLong(l)
         case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
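
For context: by the time a year-month scalar reaches toScalarForType, Spark has
already turned the java.time.Period into its month count as an Int, so a minimal
sketch of the new branch (with a made-up value) is:

    // Period.ofYears(1).plusMonths(2) is carried as 14 months (12 * 1 + 2)
    val s: Scalar = GpuTypeShims.toScalarForType(YearMonthIntervalType(), 14)
    // s is an INT32 cudf Scalar holding 14; any non-Int value throws
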
diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala
index 8dea131480b..e5daa65acfb 100644
--- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala
+++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala
@@ -30,7 +30,7 @@
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.rapids._
-import org.apache.spark.sql.rapids.shims.GpuTimeAdd
+import org.apache.spark.sql.rapids.shims.{GpuMultiplyDTInterval, GpuMultiplyYMInterval, GpuTimeAdd}
 import org.apache.spark.sql.types.{CalendarIntervalType, DayTimeIntervalType, DecimalType, StructType}
 import org.apache.spark.unsafe.types.CalendarInterval
@@ -181,7 +181,29 @@
         // ANSI support for ABS was added in 3.2.0 SPARK-33275
         override def convertToGpu(child: Expression): GpuExpression =
           GpuAbs(child, ansiEnabled)
-      })
+      }),
+    GpuOverrides.expr[MultiplyYMInterval](
+      "Year-month interval * number",
+      ExprChecks.binaryProject(
+        TypeSig.YEARMONTH,
+        TypeSig.YEARMONTH,
+        ("lhs", TypeSig.YEARMONTH, TypeSig.YEARMONTH),
+        ("rhs", TypeSig.gpuNumeric - TypeSig.DECIMAL_128, TypeSig.gpuNumeric)),
+      (a, conf, p, r) => new BinaryExprMeta[MultiplyYMInterval](a, conf, p, r) {
+        override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
+          GpuMultiplyYMInterval(lhs, rhs)
+      }),
+    GpuOverrides.expr[MultiplyDTInterval](
+      "Day-time interval * number",
+      ExprChecks.binaryProject(
+        TypeSig.DAYTIME,
+        TypeSig.DAYTIME,
+        ("lhs", TypeSig.DAYTIME, TypeSig.DAYTIME),
+        ("rhs", TypeSig.gpuNumeric - TypeSig.DECIMAL_128, TypeSig.gpuNumeric)),
+      (a, conf, p, r) => new BinaryExprMeta[MultiplyDTInterval](a, conf, p, r) {
+        override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
+          GpuMultiplyDTInterval(lhs, rhs)
+      })
   ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
     super.getExprs ++ map
   }
diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
new file mode 100644
index 00000000000..41160f568be
--- /dev/null
+++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
@@ -0,0 +1,323 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.shims
+
+import ai.rapids.cudf.{BinaryOperable, ColumnVector, DType, RoundMode, Scalar}
+import com.nvidia.spark.rapids.{Arm, BoolUtils, GpuBinaryExpression, GpuColumnVector, GpuScalar}
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, NullIntolerant}
+import org.apache.spark.sql.types._
+
+object IntervalUtils extends Arm {
+
+  /**
+   * Convert a long cv to an int cv, throwing an exception if any value in
+   * `longCv` exceeds the int limits.
+   * Checks (int)(long_value) == long_value
+   */
+  def castLongToIntWithOverflowCheck(longCv: ColumnVector): ColumnVector = {
+    withResource(longCv.castTo(DType.INT32)) { intResult =>
+      withResource(longCv.notEqualTo(intResult)) { notEquals =>
+        if (BoolUtils.isAnyValidTrue(notEquals)) {
+          throw new ArithmeticException("overflow occurs")
+        } else {
+          intResult.incRefCount()
+        }
+      }
+    }
+  }
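+
+  // A hedged usage sketch (illustrative values, not part of the plugin API):
+  // 42L casts down cleanly, while 5000000000L > Int.MaxValue makes the check throw.
+  //   withResource(ColumnVector.fromLongs(42L, 5000000000L)) { cv =>
+  //     IntervalUtils.castLongToIntWithOverflowCheck(cv)  // ArithmeticException
+  //   }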
+
+  def checkDecimal128CvInRange(decimal128Cv: ColumnVector, minValue: Long, maxValue: Long): Unit = {
+    // check min
+    withResource(Scalar.fromLong(minValue)) { minScalar =>
+      withResource(decimal128Cv.lessThan(minScalar)) { lessThanMin =>
+        if (BoolUtils.isAnyValidTrue(lessThanMin)) {
+          throw new ArithmeticException("overflow occurs")
+        }
+      }
+    }
+
+    // check max
+    withResource(Scalar.fromLong(maxValue)) { maxScalar =>
+      withResource(decimal128Cv.greaterThan(maxScalar)) { greaterThanMax =>
+        if (BoolUtils.isAnyValidTrue(greaterThanMax)) {
+          throw new ArithmeticException("overflow occurs")
+        }
+      }
+    }
+  }
+
+  /**
+   * Multiply with overflow check, then cast to long
+   * Equivalent to Math.multiplyExact
+   *
+   * @param left cv or scalar
+   * @param right cv or scalar; will not be a scalar if left is a scalar
+   * @return the long result of left * right
+   */
+  def multipleToLongWithOverflowCheck(left: BinaryOperable, right: BinaryOperable): ColumnVector = {
+    val decimal128Type = DType.create(DType.DTypeEnum.DECIMAL128, 0)
+    withResource(left.mul(right, decimal128Type)) { ret =>
+      checkDecimal128CvInRange(ret, Long.MinValue, Long.MaxValue)
+      ret.castTo(DType.INT64)
+    }
+  }
+
+  /**
+   * Multiply with overflow check, then cast to int
+   * Equivalent to Math.multiplyExact
+   *
+   * @param left cv or scalar
+   * @param right cv or scalar; will not be a scalar if left is a scalar
+   * @return the int result of left * right
+   */
+  def multipleToIntWithOverflowCheck(left: BinaryOperable, right: BinaryOperable): ColumnVector = {
+    val decimal128Type = DType.create(DType.DTypeEnum.DECIMAL128, 0)
+    withResource(left.mul(right, decimal128Type)) { ret =>
+      checkDecimal128CvInRange(ret, Int.MinValue, Int.MaxValue)
+      ret.castTo(DType.INT32)
+    }
+  }
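+
+  // Why DECIMAL128: a 64-bit * 64-bit product needs up to 128 bits, so the
+  // multiply is done exactly in DECIMAL128 (scale 0) and then range-checked.
+  // For example, Long.MaxValue * 2 = 18446744073709551614 fits in the
+  // decimal128 intermediate, fails checkDecimal128CvInRange(..., Long.MinValue,
+  // Long.MaxValue), and raises ArithmeticException instead of wrapping around.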
+
+  def checkDoubleInfNan(doubleCv: ColumnVector): Unit = {
+    // check infinity
+    withResource(Scalar.fromDouble(Double.PositiveInfinity)) { positiveInfScalar =>
+      withResource(doubleCv.equalTo(positiveInfScalar)) { equalsInfinity =>
+        if (BoolUtils.isAnyValidTrue(equalsInfinity)) {
+          throw new ArithmeticException("Has infinity")
+        }
+      }
+    }
+
+    // check -infinity
+    withResource(Scalar.fromDouble(Double.NegativeInfinity)) { negativeInfScalar =>
+      withResource(doubleCv.equalTo(negativeInfScalar)) { equalsInfinity =>
+        if (BoolUtils.isAnyValidTrue(equalsInfinity)) {
+          throw new ArithmeticException("Has -infinity")
+        }
+      }
+    }
+
+    // check NaN
+    withResource(doubleCv.isNan) { isNan =>
+      if (BoolUtils.isAnyValidTrue(isNan)) {
+        throw new ArithmeticException("Has NaN")
+      }
+    }
+  }
+
+  /**
+   * Round double cv to int with overflow check
+   * Equivalent to
+   * com.google.common.math.DoubleMath.roundToInt(double value, RoundingMode.HALF_UP)
+   */
+  def roundDoubleToIntWithOverflowCheck(doubleCv: ColumnVector): ColumnVector = {
+    // check Inf, -Inf, NaN
+    checkDoubleInfNan(doubleCv)
+
+    withResource(doubleCv.round(RoundMode.HALF_UP)) { roundedDouble =>
+      // throws exception if the result exceeds int limits
+      withResource(roundedDouble.castTo(DType.INT64)) { long =>
+        castLongToIntWithOverflowCheck(long)
+      }
+    }
+  }
+
+  /**
+   * Round double cv to long with overflow check
+   * Equivalent to
+   * com.google.common.math.DoubleMath.roundToLong(double value, RoundingMode.HALF_UP)
+   */
+  def roundDoubleToLongWithOverflowCheck(doubleCv: ColumnVector): ColumnVector = {
+    // check Inf, -Inf, NaN
+    checkDoubleInfNan(doubleCv)
+
+    roundToLongWithCheck(doubleCv)
+  }
+
+  /**
+   * Round a double cv to long, throwing if the result exceeds the long limits.
+   * Rewritten from com.google.common.math.DoubleMath.roundToLong:
+   *   z = roundIntermediate(x, mode)
+   *   checkInRange(MIN_LONG_AS_DOUBLE - z < 1.0 & z < MAX_LONG_AS_DOUBLE_PLUS_ONE)
+   *   return z.toLong
+   */
+  def roundToLongWithCheck(doubleCv: ColumnVector): ColumnVector = {
+    val MIN_LONG_AS_DOUBLE: Double = -9.223372036854776E18
+    val MAX_LONG_AS_DOUBLE_PLUS_ONE: Double = 9.223372036854776E18
+
+    withResource(doubleCv.round(RoundMode.HALF_UP)) { z =>
+      withResource(Scalar.fromDouble(MAX_LONG_AS_DOUBLE_PLUS_ONE)) { max =>
+        withResource(z.greaterOrEqualTo(max)) { invalid =>
+          if (BoolUtils.isAnyValidTrue(invalid)) {
+            throw new ArithmeticException("Round double to long overflow")
+          }
+        }
+      }
+
+      withResource(Scalar.fromDouble(MIN_LONG_AS_DOUBLE)) { min =>
+        withResource(min.sub(z)) { diff =>
+          withResource(Scalar.fromDouble(1.0d)) { one =>
+            withResource(diff.greaterOrEqualTo(one)) { invalid =>
+              if (BoolUtils.isAnyValidTrue(invalid)) {
+                throw new ArithmeticException("Round double to long overflow")
+              }
+            }
+          }
+        }
+      }
+
+      z.castTo(DType.INT64)
+    }
+  }
+}
+
+/**
+ * Multiply a year-month interval by a numeric:
+ *   year-month interval * number(byte, short, int, long, float, double)
+ * Note: year-month interval * decimal is not supported.
+ * A year-month interval's internal type is int; the int value is 12 * year + month.
+ * The left expression is the interval, the right expression is the number.
+ * Rewritten from the Spark code:
+ * https://github.com/apache/spark/blob/v3.2.1/sql/catalyst/src/main/scala/
+ * org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala#L506
+ */
+case class GpuMultiplyYMInterval(
+    interval: Expression,
+    num: Expression) extends GpuBinaryExpression with ImplicitCastInputTypes with NullIntolerant {
+
+  override def left: Expression = interval
+
+  override def right: Expression = num
+
+  override def doColumnar(interval: GpuColumnVector, numScalar: GpuScalar): ColumnVector = {
+    doColumnarImp(interval.getBase, numScalar.getBase, num.dataType)
+  }
+
+  override def doColumnar(interval: GpuColumnVector, num: GpuColumnVector): ColumnVector = {
+    doColumnarImp(interval.getBase, num.getBase, num.dataType)
+  }
+
+  override def doColumnar(intervalScalar: GpuScalar, num: GpuColumnVector): ColumnVector = {
+    doColumnarImp(intervalScalar.getBase, num.getBase, num.dataType)
+  }
+
+  override def doColumnar(numRows: Int, intervalScalar: GpuScalar,
+      numScalar: GpuScalar): ColumnVector = {
+    withResource(GpuColumnVector.from(intervalScalar, numRows, interval.dataType)) { expandedLhs =>
+      doColumnar(expandedLhs, numScalar)
+    }
+  }
+
+  private def doColumnarImp(interval: BinaryOperable, numOperable: BinaryOperable,
+      numType: DataType): ColumnVector = {
+    numType match {
+      case ByteType | ShortType | IntegerType => // num is byte, short or int
+        // compute interval * num
+        // interval and num are both in the range: [Int.MinValue, Int.MaxValue],
+        // so long fits the result and no need to check overflow
+        val longResultCv: ColumnVector = interval.mul(numOperable, DType.INT64)
+
+        withResource(longResultCv) { longResult =>
+          // throws exception if exceeds int limits
+          IntervalUtils.castLongToIntWithOverflowCheck(longResult)
+        }
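+
+      // Worked example for the branch above (illustrative values): 14 months * 3
+      // is computed as 42 in INT64, passes the INT32 range check, and becomes
+      // 42 months (3 years 6 months); Int.MaxValue months * 2 would throw
+      // ArithmeticException instead.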
+
+      case LongType => // num is long
+        // The following is equivalent to Math.toIntExact(Math.multiplyExact(months, num))
+        IntervalUtils.multipleToIntWithOverflowCheck(interval, numOperable)
+
+      case FloatType | DoubleType => // num is float or double
+        val doubleResultCv = interval.mul(numOperable, DType.FLOAT64)
+
+        withResource(doubleResultCv) { doubleResult =>
+          // round to int with overflow check
+          IntervalUtils.roundDoubleToIntWithOverflowCheck(doubleResult)
+        }
+      case _ => throw new IllegalArgumentException(
+        s"Unsupported num type $numType in GpuMultiplyYMInterval")
+    }
+  }
+
+  override def toString: String = s"$interval * $num"
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(YearMonthIntervalType, NumericType)
+
+  override def dataType: DataType = YearMonthIntervalType()
+}
+
+/**
+ * Multiply a day-time interval by a numeric:
+ *   day-time interval * number(byte, short, int, long, float, double)
+ * Note: day-time interval * decimal is not supported.
+ * A day-time interval's internal type is long; the long value is the total microseconds.
+ * Rewritten from the Spark code:
+ * https://github.com/apache/spark/blob/v3.2.1/sql/catalyst/src/main/scala/
+ * org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala#L558
+ */
+case class GpuMultiplyDTInterval(
+    interval: Expression,
+    num: Expression)
+    extends GpuBinaryExpression with ImplicitCastInputTypes with NullIntolerant {
+
+  override def left: Expression = interval
+
+  override def right: Expression = num
+
+  override def doColumnar(interval: GpuColumnVector, numScalar: GpuScalar): ColumnVector = {
+    doColumnarImp(interval.getBase, numScalar.getBase, num.dataType)
+  }
+
+  override def doColumnar(interval: GpuColumnVector, num: GpuColumnVector): ColumnVector = {
+    doColumnarImp(interval.getBase, num.getBase, num.dataType)
+  }
+
+  override def doColumnar(intervalScalar: GpuScalar, num: GpuColumnVector): ColumnVector = {
+    doColumnarImp(intervalScalar.getBase, num.getBase, num.dataType)
+  }
+
+  override def doColumnar(numRows: Int, intervalScalar: GpuScalar,
+      numScalar: GpuScalar): ColumnVector = {
+    withResource(GpuColumnVector.from(intervalScalar, numRows, interval.dataType)) { expandedLhs =>
+      doColumnar(expandedLhs, numScalar)
+    }
+  }
+
+  private def doColumnarImp(interval: BinaryOperable, numOperable: BinaryOperable,
+      numType: DataType): ColumnVector = {
+    numType match {
+      case ByteType | ShortType | IntegerType | LongType =>
+        // interval is long type; num is byte, short, int or long
+        IntervalUtils.multipleToLongWithOverflowCheck(interval, numOperable)
+      case FloatType | DoubleType =>
+        // interval is long type; num is float or double
+        val doubleResultCv = interval.mul(numOperable, DType.FLOAT64)
+        withResource(doubleResultCv) { doubleResult =>
+          // round to long with overflow check
+          IntervalUtils.roundDoubleToLongWithOverflowCheck(doubleResult)
+        }
+      case _ => throw new IllegalArgumentException(
+        s"Unsupported num type $numType in GpuMultiplyDTInterval")
+    }
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DayTimeIntervalType, NumericType)
+
+  override def dataType: DataType = DayTimeIntervalType()
+
+}
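
With the two overrides registered above, a quick end-to-end sketch (assuming the
plugin is active, e.g. spark.rapids.sql.enabled=true) that should exercise both
new expressions on the GPU:

    spark.sql("SELECT INTERVAL '2-1' YEAR TO MONTH * 3").show()  // 6-3, via GpuMultiplyYMInterval
    spark.sql("SELECT INTERVAL '10' SECOND * 1.5D").show()       // 15 seconds, via GpuMultiplyDTInterval
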
diff --git a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalMultiplySuite.scala b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalMultiplySuite.scala
new file mode 100644
index 00000000000..461e9964155
--- /dev/null
+++ b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalMultiplySuite.scala
@@ -0,0 +1,400 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import java.time.{Duration, Period}
+
+import scala.util.Random
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+
+class IntervalMultiplySuite extends SparkQueryCompareTestSuite {
+  testSparkResultsAreEqual(
+    "test year-month interval * integer num, normal case",
+    spark => {
+      val data = Seq(
+        Row(Period.ofMonths(5), 3L, 3),
+        Row(Period.ofMonths(5), null, null),
+        Row(null, 5L, 5),
+        Row(Period.ofMonths(6), 0L, 0),
+        Row(Period.ofMonths(0), 6L, 6),
+        Row(Period.ofMonths(0), Long.MinValue, Int.MinValue),
+        Row(Period.ofMonths(0), Long.MaxValue, Int.MaxValue)
+      )
+      val schema = StructType(Seq(
+        StructField("c_ym", YearMonthIntervalType()),
+        StructField("c_l", LongType),
+        StructField("c_i", IntegerType)))
+      spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    }
+  ) {
+    df =>
+      df.selectExpr(
+        "c_ym * c_l", "c_ym * c_i", "c_ym * 15", "c_ym * 15L", "c_ym * 1.15f", "c_ym * 1.15d")
+  }
+
+  testSparkResultsAreEqual(
+    "test year-month interval * integer num, normal case 2",
+    spark => {
+      val data = Seq(
+        Row(1.toByte, 1.toShort, 1, 1L, 1.5f, 1.5d)
+      )
+      val schema = StructType(Seq(
+        StructField("c_b", ByteType),
+        StructField("c_s", ShortType),
+        StructField("c_i", IntegerType),
+        StructField("c_l", LongType),
+        StructField("c_f", FloatType),
+        StructField("c_d", DoubleType)))
+
+      spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    }
+  ) {
+    df =>
+      df.selectExpr(
+        "interval '5' month * c_b",
+        "interval '5' month * c_s",
+        "interval '5' month * c_i",
+        "interval '5' month * c_l",
+        "interval '5' month * c_f",
+        "interval '5' month * c_d"
+      )
+  }
+
+  testSparkResultsAreEqual("test year-month interval * float num, normal case",
+    spark => {
+      val r = new Random(0)
+      val data = (0 until 1024).map(i => {
+        val sign = if (r.nextBoolean()) 1 else -1
+        if (i % 10 == 0) {
+          Row(Period.ofMonths(r.nextInt(1024)), Float.MinPositiveValue, Double.MinPositiveValue)
+        } else {
+          Row(Period.ofMonths(r.nextInt(1024)), sign * r.nextFloat(), sign * r.nextDouble())
+        }
+      })
+
+      val schema = StructType(Seq(
+        StructField("c_ym", YearMonthIntervalType()),
+        StructField("c_f", FloatType),
+        StructField("c_d", DoubleType)))
+      spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    }
+  ) {
+    df => df.selectExpr("c_ym * c_f", "c_ym * c_d")
+  }
+
+  def testOverflowMultipyInt(testCaseName: String, month: Int, num: Int): Unit = {
+    testBothCpuGpuExpectedException[SparkException](testCaseName,
+      e => e.getMessage.contains("ArithmeticException"),
+      spark => {
+        val data = Seq(Row(Period.ofMonths(month)))
+        val schema = StructType(Seq(StructField("c1", YearMonthIntervalType())))
+        spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+      }
+    ) {
+      df => df.selectExpr(s"c1 * $num")
+    }
+  }
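+
+  // Each case below should overflow 32-bit months on both CPU and GPU; for
+  // example, 2147483647 months * 2 = 4294967294, which exceeds Int.MaxValue
+  // (2147483647), so an ArithmeticException is expected.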
testOverflowMultipyInt("test year-month interval * int overflow, case 2", 3, Int.MaxValue / 2) + testOverflowMultipyInt("test year-month interval * int overflow, case 3", Int.MinValue, 2) + testOverflowMultipyInt("test year-month interval * int overflow, case 4", -1, Int.MinValue) + testOverflowMultipyInt("test year-month interval * int overflow, case 5", Int.MinValue, -1) + + + def testOverflowMultipyLong(testCaseName: String, month: Int, num: Long): Unit = { + testBothCpuGpuExpectedException[SparkException](testCaseName, + e => e.getMessage.contains("ArithmeticException"), + spark => { + val data = Seq(Row(Period.ofMonths(month))) + val schema = StructType(Seq(StructField("c1", YearMonthIntervalType()))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + ) { + df => df.selectExpr(s"c1 * ${num}L") + } + } + + testOverflowMultipyLong( + "test year-month interval * long overflow case 1", -1, Long.MinValue) + testOverflowMultipyLong( + "test year-month interval * long overflow case 2", Int.MaxValue, -2) + testOverflowMultipyLong( + "test year-month interval * long overflow case 3", 1, Int.MaxValue.toLong + 1L) + testOverflowMultipyLong( + "test year-month interval * long overflow case 4", -1, Int.MinValue.toLong) + testOverflowMultipyLong( + "test year-month interval * long overflow case 5", -1, Long.MinValue) + testOverflowMultipyLong( + "test year-month interval * long overflow case 6", 2, Long.MaxValue) + + def testOverflowMultipyLong2(testCaseName: String, month: Int, num: Long): Unit = { + testBothCpuGpuExpectedException[SparkException](testCaseName, + e => e.getMessage.contains("ArithmeticException"), + spark => { + val data = Seq(Row(Period.ofMonths(month), num)) + val schema = StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", LongType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + ) { + df => df.selectExpr("c1 * c2") + } + } + + testOverflowMultipyLong2( + "test year-month interval * long overflow case 21", -1, Long.MinValue) + testOverflowMultipyLong2( + "test year-month interval * long overflow case 22", Int.MinValue, -1L) + testOverflowMultipyLong2( + "test year-month interval * long overflow case 23", 1, Int.MinValue.toLong - 1L) + testOverflowMultipyLong2( + "test year-month interval * long overflow case 24", 2, Int.MinValue.toLong + 1L) + testOverflowMultipyLong2( + "test year-month interval * long overflow case 25", -2, Long.MaxValue) + + def testOverflowMultipyFloat(testCaseName: String, month: Int, num: Float): Unit = { + testBothCpuGpuExpectedException[SparkException](testCaseName, + e => e.getMessage.contains("ArithmeticException"), + spark => { + val data = Seq(Row(Period.ofMonths(month), num)) + val schema = StructType(Seq( + StructField("c1", YearMonthIntervalType()), + StructField("c2", FloatType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + ) { + df => df.selectExpr("c1 * c2") + } + } + + testOverflowMultipyFloat("test year-month interval * float overflow case 1", 1, Float.NaN) + testOverflowMultipyFloat( + "test year-month interval * float overflow case 2", 1, Float.PositiveInfinity) + testOverflowMultipyFloat( + "test year-month interval * float overflow case 3", 1, Float.NegativeInfinity) + testOverflowMultipyFloat( + "test year-month interval * float overflow case 4", 2, Long.MinValue.toFloat) + testOverflowMultipyFloat( + "test year-month interval * float overflow case 5", 2, Long.MaxValue.toFloat) + + def 
+
+  def testOverflowMultipyDouble(testCaseName: String, month: Int, num: Double): Unit = {
+    testBothCpuGpuExpectedException[SparkException](testCaseName,
+      e => e.getMessage.contains("ArithmeticException"),
+      spark => {
+        val data = Seq(Row(Period.ofMonths(month), num))
+        val schema = StructType(Seq(
+          StructField("c1", YearMonthIntervalType()),
+          StructField("c2", DoubleType)
+        ))
+        spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+      }
+    ) {
+      df => df.selectExpr("c1 * c2")
+    }
+  }
+
+  testOverflowMultipyDouble("test year-month interval * double overflow case 1", 1, Double.NaN)
+  testOverflowMultipyDouble(
+    "test year-month interval * double overflow case 2", 1, Double.PositiveInfinity)
+  testOverflowMultipyDouble(
+    "test year-month interval * double overflow case 3", 1, Double.NegativeInfinity)
+  testOverflowMultipyDouble(
+    "test year-month interval * double overflow case 4", 2, Long.MinValue.toDouble)
+  testOverflowMultipyDouble(
+    "test year-month interval * double overflow case 5", 2, Long.MaxValue.toDouble)
+
+  // The following are day-time test cases
+
+  def testOverflowDTMultipyInt(testCaseName: String, microSeconds: Long, num: Int): Unit = {
+    testBothCpuGpuExpectedException[SparkException](testCaseName,
+      e => e.getMessage.contains("ArithmeticException"),
+      spark => {
+        val d = Duration.ofSeconds(microSeconds / 1000000, microSeconds % 1000000 * 1000)
+        val data = Seq(Row(d))
+        val schema = StructType(Seq(StructField("c1", DayTimeIntervalType())))
+        spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+      }
+    ) {
+      df => df.selectExpr(s"c1 * $num")
+    }
+  }
+
+  testOverflowDTMultipyInt("test day-time interval * int overflow, case 1", Long.MaxValue, 2)
+  testOverflowDTMultipyInt("test day-time interval * int overflow, case 2", Long.MaxValue / 2, 3)
+  testOverflowDTMultipyInt("test day-time interval * int overflow, case 3", Long.MinValue, 2)
+  testOverflowDTMultipyInt("test day-time interval * int overflow, case 4", Long.MinValue, -1)
+  testOverflowDTMultipyInt("test day-time interval * int overflow, case 5", Long.MinValue, -2)
+
+  def testOverflowDTMultipyLong(testCaseName: String, microSeconds: Long, num: Long): Unit = {
+    testBothCpuGpuExpectedException[SparkException](testCaseName,
+      e => e.getMessage.contains("ArithmeticException"),
+      spark => {
+        val d = Duration.ofSeconds(microSeconds / 1000000, microSeconds % 1000000 * 1000)
+        val data = Seq(Row(d))
+        val schema = StructType(Seq(StructField("c1", DayTimeIntervalType())))
+        spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+      }
+    ) {
+      df => df.selectExpr(s"c1 * ${num}L")
+    }
+  }
+
+  testOverflowDTMultipyLong(
+    "test day-time interval * long overflow case 1", 2, Long.MaxValue)
+  testOverflowDTMultipyLong(
+    "test day-time interval * long overflow case 2", 3, Long.MinValue)
+  testOverflowDTMultipyLong(
+    "test day-time interval * long overflow case 3", Long.MinValue, -1)
+  testOverflowDTMultipyLong(
+    "test day-time interval * long overflow case 4", -1L, Long.MinValue)
+  testOverflowDTMultipyLong(
+    "test day-time interval * long overflow case 5", 2, Long.MaxValue)
+  testOverflowDTMultipyLong(
+    "test day-time interval * long overflow case 6", -2, Long.MaxValue)
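+
+  // Reminder of the encoding (see GpuMultiplyDTInterval): a day-time interval is
+  // backed by a Long count of microseconds, so even a tiny duration such as
+  // 2 microseconds * Long.MaxValue overflows 64-bit math and should throw on
+  // both CPU and GPU.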
+
+  def testOverflowDTMultipyLong2(testCaseName: String, microSeconds: Long, num: Long): Unit = {
+    testBothCpuGpuExpectedException[SparkException](testCaseName,
+      e => e.getMessage.contains("ArithmeticException"),
+      spark => {
+        val d = Duration.ofSeconds(microSeconds / 1000000, microSeconds % 1000000 * 1000)
+        val data = Seq(Row(d, num))
+        val schema = StructType(Seq(
+          StructField("c1", DayTimeIntervalType()),
+          StructField("c2", LongType)))
+        spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+      }
+    ) {
+      df => df.selectExpr("c1 * c2")
+    }
+  }
+
+  testOverflowDTMultipyLong2(
+    "test day-time interval * long overflow case 21", -1L, Long.MinValue)
+  testOverflowDTMultipyLong2(
+    "test day-time interval * long overflow case 22", Long.MinValue, -1L)
+  testOverflowDTMultipyLong2(
+    "test day-time interval * long overflow case 23", Long.MinValue, 2)
+  testOverflowDTMultipyLong2(
+    "test day-time interval * long overflow case 24", -2L, Long.MinValue)
+  testOverflowDTMultipyLong2(
+    "test day-time interval * long overflow case 25", 2, Long.MaxValue)
+  testOverflowDTMultipyLong2(
+    "test day-time interval * long overflow case 26", -3, Long.MaxValue)
+
+  def testOverflowDTMultipyFloat(testCaseName: String, microSeconds: Long, num: Float): Unit = {
+    testBothCpuGpuExpectedException[SparkException](testCaseName,
+      e => e.getMessage.contains("ArithmeticException"),
+      spark => {
+        val d = Duration.ofSeconds(microSeconds / 1000000, microSeconds % 1000000 * 1000)
+        val data = Seq(Row(d, num))
+        val schema = StructType(Seq(
+          StructField("c1", DayTimeIntervalType()),
+          StructField("c2", FloatType)
+        ))
+        spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+      }
+    ) {
+      df => df.selectExpr("c1 * c2")
+    }
+  }
+
+  testOverflowDTMultipyFloat(
+    "test day-time interval * float overflow case 1", 1, Float.NaN)
+  testOverflowDTMultipyFloat(
+    "test day-time interval * float overflow case 2", 1, Float.PositiveInfinity)
+  testOverflowDTMultipyFloat(
+    "test day-time interval * float overflow case 3", 1, Float.NegativeInfinity)
+  testOverflowDTMultipyFloat(
+    "test day-time interval * float overflow case 4", -1, Long.MinValue.toFloat)
+  testOverflowDTMultipyFloat(
+    "test day-time interval * float overflow case 5", 2, Long.MaxValue.toFloat)
+
+  def testOverflowDTMultipyDouble(testCaseName: String, microSeconds: Long, num: Double): Unit = {
+    testBothCpuGpuExpectedException[SparkException](testCaseName,
+      e => e.getMessage.contains("ArithmeticException"),
+      spark => {
+        val d = Duration.ofSeconds(microSeconds / 1000000, microSeconds % 1000000 * 1000)
+        val data = Seq(Row(d, num))
+        val schema = StructType(Seq(
+          StructField("c1", DayTimeIntervalType()),
+          StructField("c2", DoubleType)
+        ))
+        spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+      }
+    ) {
+      df => df.selectExpr("c1 * c2")
+    }
+  }
+
+  testOverflowDTMultipyDouble(
+    "test day-time interval * double overflow case 1", 1, Double.NaN)
+  testOverflowDTMultipyDouble(
+    "test day-time interval * double overflow case 2", 1, Double.PositiveInfinity)
+  testOverflowDTMultipyDouble(
+    "test day-time interval * double overflow case 3", 1, Double.NegativeInfinity)
+  testOverflowDTMultipyDouble(
+    "test day-time interval * double overflow case 4", 1, Float.NegativeInfinity)
+  testOverflowDTMultipyDouble(
+    "test day-time interval * double overflow case 5", -1, Long.MinValue.toDouble)
+  testOverflowDTMultipyDouble(
+    "test day-time interval * double overflow case 6", 2, Long.MaxValue.toDouble)
+  testOverflowDTMultipyDouble(
+    "test day-time interval * double overflow case 7", 3, Long.MinValue.toDouble)
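+
+  // Sanity expectation for the normal case below: interval '5' second * 1.5d is
+  // 7.5 seconds (7500000 microseconds), and CPU and GPU should agree exactly.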
StructField("c_dt", DayTimeIntervalType()), + StructField("c_b", ByteType), + StructField("c_s", ShortType), + StructField("c_i", IntegerType), + StructField("c_l", LongType), + StructField("c_f", FloatType), + StructField("c_d", DoubleType))) + + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + ) { + df => + df.selectExpr( + "interval '5' second * c_b", + "interval '5' second * c_s", + "interval '5' second * c_i", + "interval '5' second * c_l", + "interval '5' second * c_f", + "interval '5' second * c_d", + "c_dt * 1.5f", + "c_dt * 1.5d", + "c_dt * 5L", + "c_dt * 5" + ) + } +}