Skip to content

Commit

Permalink
[SPARK-36924][SQL] CAST between ANSI intervals and IntegralType
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Add cast `AnsiIntervalType` to `IntegralType`

requirement:
1. `YearMonthIntervalType` just have one unit
2. `DayTimeIntervalType` just have one unit

cast rule:
1. The value corresponding to the unit of `YearMonthIntervalType` is the value of the `IntegralType` after conversion.
2. The value corresponding to the unit of `DayTimeIntervalType` is the value of the `IntegralType` after conversion.

Add cast `IntegralType` to `AnsiIntervalType`
requirement:

1. `YearMonthIntervalType` just have one unit
2. `DayTimeIntervalType` just have one unit

cast rule:

1. The value of the IntegralTypeis the value of  `YearMonthIntervalType` that with the single unit after conversion.
2. The value of the IntegralTypeis the value of  `DayTimeIntervalType` that with the single unit after conversion.

### Why are the changes needed?
According to 2011 Standards
![截图](https://user-images.githubusercontent.com/41178002/140504037-b86793f0-2c97-49f7-bcbf-bb6864592aa8.PNG)

7) If TD is an interval and SD is exact numeric, then TD shall contain only a single <primary datetime field>.
8) If TD is exact numeric and SD is an interval, then SD shall contain only a single <primary datetime field>.

### Does this PR introduce _any_ user-facing change?
Yes, user can use cast function between YearMonthIntervalType to NumericType

### How was this patch tested?
add ut testcase

Closes #34494 from Peng-Lei/SPARK-36924.

Authored-by: PengLei <peng.8lei@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
Peng-Lei authored and cloud-fan committed Nov 17, 2021
1 parent 785ca85 commit 9553ed7
Show file tree
Hide file tree
Showing 3 changed files with 496 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte, dayTimeIntervalToInt, dayTimeIntervalToLong, dayTimeIntervalToShort, yearMonthIntervalToByte, yearMonthIntervalToInt, yearMonthIntervalToShort}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -81,9 +82,13 @@ object Cast {
case (StringType, CalendarIntervalType) => true
case (StringType, _: DayTimeIntervalType) => true
case (StringType, _: YearMonthIntervalType) => true
case (_: IntegralType, DayTimeIntervalType(s, e)) if s == e => true
case (_: IntegralType, YearMonthIntervalType(s, e)) if s == e => true

case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
case (_: DayTimeIntervalType, _: IntegralType) => true
case (_: YearMonthIntervalType, _: IntegralType) => true

case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true
Expand Down Expand Up @@ -589,6 +594,15 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
IntervalUtils.castStringToDTInterval(s, it.startField, it.endField))
case _: DayTimeIntervalType => buildCast[Long](_, s =>
IntervalUtils.durationToMicros(IntervalUtils.microsToDuration(s), it.endField))
case x: IntegralType =>
assert(it.startField == it.endField)
if (x == LongType) {
b => IntervalUtils.longToDayTimeInterval(
x.integral.asInstanceOf[Integral[Any]].toLong(b), it.endField)
} else {
b => IntervalUtils.intToDayTimeInterval(
x.integral.asInstanceOf[Integral[Any]].toInt(b), it.endField)
}
}

private[this] def castToYearMonthInterval(
Expand All @@ -598,6 +612,15 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
IntervalUtils.castStringToYMInterval(s, it.startField, it.endField))
case _: YearMonthIntervalType => buildCast[Int](_, s =>
IntervalUtils.periodToMonths(IntervalUtils.monthsToPeriod(s), it.endField))
case x: IntegralType =>
assert(it.startField == it.endField)
if (x == LongType) {
b => IntervalUtils.longToYearMonthInterval(
x.integral.asInstanceOf[Integral[Any]].toLong(b), it.endField)
} else {
b => IntervalUtils.intToYearMonthInterval(
x.integral.asInstanceOf[Integral[Any]].toInt(b), it.endField)
}
}

// LongConverter
Expand All @@ -617,6 +640,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
b => x.exactNumeric.asInstanceOf[Numeric[Any]].toLong(b)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
case x: DayTimeIntervalType =>
buildCast[Long](_, i => dayTimeIntervalToLong(i, x.startField, x.endField))
case x: YearMonthIntervalType =>
buildCast[Int](_, i => yearMonthIntervalToInt(i, x.startField, x.endField).toLong)
}

// IntConverter
Expand Down Expand Up @@ -645,6 +672,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
b => x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(b)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
case x: DayTimeIntervalType =>
buildCast[Long](_, i => dayTimeIntervalToInt(i, x.startField, x.endField))
case x: YearMonthIntervalType =>
buildCast[Int](_, i => yearMonthIntervalToInt(i, x.startField, x.endField))
}

// ShortConverter
Expand Down Expand Up @@ -688,6 +719,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
}
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
case x: DayTimeIntervalType =>
buildCast[Long](_, i => dayTimeIntervalToShort(i, x.startField, x.endField))
case x: YearMonthIntervalType =>
buildCast[Int](_, i => yearMonthIntervalToShort(i, x.startField, x.endField))
}

// ByteConverter
Expand Down Expand Up @@ -731,6 +766,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
}
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
case x: DayTimeIntervalType =>
buildCast[Long](_, i => dayTimeIntervalToByte(i, x.startField, x.endField))
case x: YearMonthIntervalType =>
buildCast[Int](_, i => yearMonthIntervalToByte(i, x.startField, x.endField))
}

/**
Expand Down Expand Up @@ -1502,6 +1541,20 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
code"""
$evPrim = $util.durationToMicros($util.microsToDuration($c), (byte)${it.endField});
"""
case x: IntegralType =>
assert(it.startField == it.endField)
val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
if (x == LongType) {
(c, evPrim, _) =>
code"""
$evPrim = $util.longToDayTimeInterval($c, (byte)${it.endField});
"""
} else {
(c, evPrim, _) =>
code"""
$evPrim = $util.intToDayTimeInterval($c, (byte)${it.endField});
"""
}
}

private[this] def castToYearMonthIntervalCode(
Expand All @@ -1519,6 +1572,20 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
code"""
$evPrim = $util.periodToMonths($util.monthsToPeriod($c), (byte)${it.endField});
"""
case x: IntegralType =>
assert(it.startField == it.endField)
val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
if (x == LongType) {
(c, evPrim, _) =>
code"""
$evPrim = $util.longToYearMonthInterval($c, (byte)${it.endField});
"""
} else {
(c, evPrim, _) =>
code"""
$evPrim = $util.intToYearMonthInterval($c, (byte)${it.endField});
"""
}
}

private[this] def decimalToTimestampCode(d: ExprValue): Block = {
Expand Down Expand Up @@ -1580,6 +1647,28 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
}
}

private[this] def castDayTimeIntervalToIntegralTypeCode(
startField: Byte,
endField: Byte,
integralType: String): CastFunction = {
val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
(c, evPrim, _) =>
code"""
$evPrim = $util.dayTimeIntervalTo$integralType($c, (byte)$startField, (byte)$endField);
"""
}

private[this] def castYearMonthIntervalToIntegralTypeCode(
startField: Byte,
endField: Byte,
integralType: String): CastFunction = {
val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
(c, evPrim, _) =>
code"""
$evPrim = $util.yearMonthIntervalTo$integralType($c, (byte)$startField, (byte)$endField);
"""
}

private[this] def castDecimalToIntegralTypeCode(
ctx: CodegenContext,
integralType: String,
Expand Down Expand Up @@ -1664,6 +1753,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
castFractionToIntegralTypeCode("byte", ByteType.catalogString)
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (byte) $c;"
case x: DayTimeIntervalType =>
castDayTimeIntervalToIntegralTypeCode(x.startField, x.endField, "Byte")
case x: YearMonthIntervalType =>
castYearMonthIntervalToIntegralTypeCode(x.startField, x.endField, "Byte")
}

private[this] def castToShortCode(
Expand Down Expand Up @@ -1696,6 +1789,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
castFractionToIntegralTypeCode("short", ShortType.catalogString)
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (short) $c;"
case x: DayTimeIntervalType =>
castDayTimeIntervalToIntegralTypeCode(x.startField, x.endField, "Short")
case x: YearMonthIntervalType =>
castYearMonthIntervalToIntegralTypeCode(x.startField, x.endField, "Short")
}

private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
Expand Down Expand Up @@ -1726,6 +1823,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
castFractionToIntegralTypeCode("int", IntegerType.catalogString)
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (int) $c;"
case x: DayTimeIntervalType =>
castDayTimeIntervalToIntegralTypeCode(x.startField, x.endField, "Int")
case x: YearMonthIntervalType =>
castYearMonthIntervalToIntegralTypeCode(x.startField, x.endField, "Int")
}

private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
Expand Down Expand Up @@ -1755,6 +1856,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
castFractionToIntegralTypeCode("long", LongType.catalogString)
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (long) $c;"
case x: DayTimeIntervalType =>
castDayTimeIntervalToIntegralTypeCode(x.startField, x.endField, "Long")
case x: YearMonthIntervalType =>
castYearMonthIntervalToIntegralTypeCode(x.startField, x.endField, "Int")
}

private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.{DayTimeIntervalType => DT, Decimal, YearMonthIntervalType => YM}
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

// The style of textual representation of intervals
Expand Down Expand Up @@ -1253,4 +1255,120 @@ object IntervalUtils {
}
intervalString
}

def intToYearMonthInterval(v: Int, endField: Byte): Int = {
endField match {
case YEAR =>
try {
Math.multiplyExact(v, MONTHS_PER_YEAR)
} catch {
case _: ArithmeticException =>
throw QueryExecutionErrors.castingCauseOverflowError(v, YM(endField).catalogString)
}
case MONTH => v
}
}

def longToYearMonthInterval(v: Long, endField: Byte): Int = {
val vInt = v.toInt
if (v != vInt) {
throw QueryExecutionErrors.castingCauseOverflowError(v, YM(endField).catalogString)
}
intToYearMonthInterval(vInt, endField)
}

def yearMonthIntervalToInt(v: Int, startField: Byte, endField: Byte): Int = {
endField match {
case YEAR => v / MONTHS_PER_YEAR
case MONTH => v
}
}

def yearMonthIntervalToShort(v: Int, startField: Byte, endField: Byte): Short = {
val vInt = yearMonthIntervalToInt(v, startField, endField)
val vShort = vInt.toShort
if (vInt != vShort) {
throw QueryExecutionErrors.castingCauseOverflowError(
toYearMonthIntervalString(v, ANSI_STYLE, startField, endField), ShortType.catalogString)
}
vShort
}

def yearMonthIntervalToByte(v: Int, startField: Byte, endField: Byte): Byte = {
val vInt = yearMonthIntervalToInt(v, startField, endField)
val vByte = vInt.toByte
if (vInt != vByte) {
throw QueryExecutionErrors.castingCauseOverflowError(
toYearMonthIntervalString(v, ANSI_STYLE, startField, endField), ByteType.catalogString)
}
vByte
}

def intToDayTimeInterval(v: Int, endField: Byte): Long = {
endField match {
case DAY =>
try {
Math.multiplyExact(v, MICROS_PER_DAY)
} catch {
case _: ArithmeticException =>
throw QueryExecutionErrors.castingCauseOverflowError(v, DT(endField).catalogString)
}
case HOUR => v * MICROS_PER_HOUR
case MINUTE => v * MICROS_PER_MINUTE
case SECOND => v * MICROS_PER_SECOND
}
}

def longToDayTimeInterval(v: Long, endField: Byte): Long = {
try {
endField match {
case DAY => Math.multiplyExact(v, MICROS_PER_DAY)
case HOUR => Math.multiplyExact(v, MICROS_PER_HOUR)
case MINUTE => Math.multiplyExact(v, MICROS_PER_MINUTE)
case SECOND => Math.multiplyExact(v, MICROS_PER_SECOND)
}
} catch {
case _: ArithmeticException =>
throw QueryExecutionErrors.castingCauseOverflowError(v, DT(endField).catalogString)
}
}

def dayTimeIntervalToLong(v: Long, startField: Byte, endField: Byte): Long = {
endField match {
case DAY => v / MICROS_PER_DAY
case HOUR => v / MICROS_PER_HOUR
case MINUTE => v / MICROS_PER_MINUTE
case SECOND => v / MICROS_PER_SECOND
}
}

def dayTimeIntervalToInt(v: Long, startField: Byte, endField: Byte): Int = {
val vLong = dayTimeIntervalToLong(v, startField, endField)
val vInt = vLong.toInt
if (vLong != vInt) {
throw QueryExecutionErrors.castingCauseOverflowError(
toDayTimeIntervalString(v, ANSI_STYLE, startField, endField), IntegerType.catalogString)
}
vInt
}

def dayTimeIntervalToShort(v: Long, startField: Byte, endField: Byte): Short = {
val vLong = dayTimeIntervalToLong(v, startField, endField)
val vShort = vLong.toShort
if (vLong != vShort) {
throw QueryExecutionErrors.castingCauseOverflowError(
toDayTimeIntervalString(v, ANSI_STYLE, startField, endField), ShortType.catalogString)
}
vShort
}

def dayTimeIntervalToByte(v: Long, startField: Byte, endField: Byte): Byte = {
val vLong = dayTimeIntervalToLong(v, startField, endField)
val vByte = vLong.toByte
if (vLong != vByte) {
throw QueryExecutionErrors.castingCauseOverflowError(
toDayTimeIntervalString(v, ANSI_STYLE, startField, endField), ByteType.catalogString)
}
vByte
}
}
Loading

0 comments on commit 9553ed7

Please sign in to comment.