Skip to content

Commit

Permalink
code cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: Raza Jafri <rjafri@nvidia.com>
  • Loading branch information
razajafri committed Feb 26, 2021
1 parent 38da4e3 commit 8702bef
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 313 deletions.
349 changes: 174 additions & 175 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
package com.nvidia.spark.rapids

import ai.rapids.cudf.{ColumnVector, DType, Scalar}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import java.text.SimpleDateFormat
import java.time.DateTimeException

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{Cast, CastBase, Expression, NullIntolerant, TimeZoneAwareExpression}
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -127,179 +125,6 @@ object GpuCast extends Arm {
"required range"

// Error text stating that at least one value is null or is not a valid number.
// NOTE(review): presumably reported when a cast involving floats fails; the usage site is
// not visible here, so confirm against the callers.
val INVALID_FLOAT_CAST_MSG = "At least one value is either null or is an invalid number"

/**
 * Verifies that no value of a column falls outside the given range, failing otherwise.
 *
 * The lower bound is checked first and the upper bound second. Each bound is a by-name
 * parameter: it is evaluated exactly once, right before its check, and the resulting Scalar
 * is closed by this method.
 *
 * @param values column whose values are range checked; it is not closed by this method
 * @param minValue by-name expression producing the Scalar for the lower bound
 * @param maxValue by-name expression producing the Scalar for the upper bound
 * @param inclusiveMin true when a value equal to the lower bound is still valid
 * @param inclusiveMax true when a value equal to the upper bound is still valid
 * @throws IllegalStateException if any value of the column lies outside the range
 */
private def assertValuesInRange(values: ColumnVector,
    minValue: => Scalar,
    maxValue: => Scalar,
    inclusiveMin: Boolean = true,
    inclusiveMax: Boolean = true): Unit = {

  // Materializes one bound, builds the mask of violating rows with `outOfRange`, and throws
  // when the any() reduction of that mask is true. The bound, the mask and the reduction
  // result are all closed before returning or propagating the exception.
  def failOnViolation(bound: => Scalar)(outOfRange: Scalar => ColumnVector): Unit = {
    withResource(bound) { limit =>
      withResource(outOfRange(limit)) { violations =>
        withResource(violations.any()) { anyViolation =>
          if (anyViolation.getBoolean) {
            throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE)
          }
        }
      }
    }
  }

  failOnViolation(minValue) { lower =>
    if (inclusiveMin) values.lessThan(lower) else values.lessOrEqualTo(lower)
  }

  failOnViolation(maxValue) { upper =>
    if (inclusiveMax) values.greaterThan(upper) else values.greaterOrEqualTo(upper)
  }
}

/**
 * Builds a column in which every value outside the given range is substituted with a
 * replacement scalar.
 *
 * All three scalars are by-name parameters: each is evaluated once and closed by this
 * method. The bounds are evaluated first (minimum, then maximum); the replacement is only
 * evaluated after the out-of-range mask has been computed.
 *
 * @param values column whose values are range checked; it is not closed by this method
 * @param minValue by-name expression producing the Scalar for the lower bound
 * @param maxValue by-name expression producing the Scalar for the upper bound
 * @param replaceValue by-name expression producing the Scalar written over outliers
 * @param inclusiveMin true when a value equal to the lower bound is still valid
 * @param inclusiveMax true when a value equal to the upper bound is still valid
 * @return the result of selecting the replacement where the mask is set and the original
 *         value elsewhere; the caller owns the returned column
 */
private def replaceOutOfRangeValues(values: ColumnVector,
    minValue: => Scalar,
    maxValue: => Scalar,
    replaceValue: => Scalar,
    inclusiveMin: Boolean = true,
    inclusiveMax: Boolean = true): ColumnVector = {

  // Mask of rows that fall below the valid range.
  def belowMinMask(lower: Scalar): ColumnVector =
    if (inclusiveMin) values.lessThan(lower) else values.lessOrEqualTo(lower)

  // Mask of rows that fall above the valid range.
  def aboveMaxMask(upper: Scalar): ColumnVector =
    if (inclusiveMax) values.greaterThan(upper) else values.greaterOrEqualTo(upper)

  withResource(minValue) { lower =>
    withResource(maxValue) { upper =>
      withResource(belowMinMask(lower)) { belowMin =>
        withResource(aboveMaxMask(upper)) { aboveMax =>
          withResource(aboveMax.or(belowMin)) { outOfRange =>
            withResource(replaceValue) { substitute =>
              outOfRange.ifElse(substitute, values)
            }
          }
        }
      }
    }
  }
}

/**
 * Casts a decimal column from one decimal type to another.
 *
 * An overflow check runs first whenever the target scale is larger than the source scale,
 * or when a 64-bit decimal is narrowed to a 32-bit decimal. Values that fail the check
 * raise an error in ANSI mode and are replaced with null otherwise.
 *
 * @param input column to cast; it is not closed by this method
 * @param from decimal type of `input`
 * @param to decimal type to cast to
 * @param ansiMode when true, out-of-range values cause an exception instead of becoming null
 * @return the cast column, owned by the caller; when no conversion is required this is the
 *         value returned by `incRefCount()` on the input
 * @throws IllegalStateException in ANSI mode if a value is out of range for the target, or
 *                               (on the path without an overflow check) if the target scale
 *                               is negative and negative scales are not enabled in SQLConf
 */
private[rapids] def castDecimalToDecimal(input: ColumnVector,
    from: DecimalType,
    to: DecimalType,
    ansiMode: Boolean): ColumnVector = {

  val isFrom32Bit = DecimalType.is32BitDecimalType(from)
  val isTo32Bit = DecimalType.is32BitDecimalType(to)
  val cudfDecimal = DecimalUtil.createCudfDecimal(to.precision, to.scale)

  // Converts an already range-checked column to the target cuDF decimal type.
  def castCheckedDecimal(checkedInput: ColumnVector): ColumnVector = {
    to.scale - from.scale match {
      case 0 =>
        if (isFrom32Bit == isTo32Bit) {
          // same scale and same width: nothing to convert
          checkedInput.incRefCount()
        } else {
          // the input is already checked, just cast it
          checkedInput.castTo(cudfDecimal)
        }
      case diff if diff > 0 =>
        checkedInput.castTo(cudfDecimal)
      case _ =>
        // the target has fewer fractional digits: round before casting
        withResource(checkedInput.round(to.scale, ai.rapids.cudf.RoundMode.HALF_UP)) {
          rounded => rounded.castTo(cudfDecimal)
        }
    }
  }

  // Returns a column (owned by the caller) whose values are safe to cast to the target.
  def checkForOverflow: ColumnVector = {
    // Check whether there exists overflow during promoting precision or not.
    // We do NOT use `Scalar.fromDecimal(-to.scale, math.pow(10, 18).toLong)` here, because
    // cuDF binaryOperation on decimal will rescale right input to fit the left one.
    // The rescaling may lead to overflow.
    // 10^prec is the exclusive bound on the unscaled input values before the cast.
    val prec = to.precision + from.scale - to.scale

    // When the bound needs more digits than the input representation can hold, no input
    // value can reach it, so go ahead with the cast without further checking.
    val boundExceedsInput = isFrom32Bit && prec > Decimal.MAX_INT_DIGITS ||
      !isFrom32Bit && prec > Decimal.MAX_LONG_DIGITS

    if (boundExceedsInput) {
      input.incRefCount()
    } else {
      // Creates one bound on demand. The helpers below take their scalars by name and close
      // them, so a scalar must only come into existence when the helper evaluates it.
      // Creating both bounds eagerly leaked the max bound whenever the min check threw.
      def boundScalar(negate: Boolean): Scalar = {
        if (isFrom32Bit) {
          val absBound = math.pow(10, prec).toInt
          Scalar.fromDecimal(-from.scale, if (negate) -absBound else absBound)
        } else {
          val absBound = math.pow(10, prec).toLong
          Scalar.fromDecimal(-from.scale, if (negate) -absBound else absBound)
        }
      }

      if (ansiMode) {
        assertValuesInRange(input,
          minValue = boundScalar(negate = true),
          maxValue = boundScalar(negate = false),
          inclusiveMin = false, inclusiveMax = false)
        input.incRefCount()
      } else {
        replaceOutOfRangeValues(input,
          minValue = boundScalar(negate = true),
          maxValue = boundScalar(negate = false),
          replaceValue = Scalar.fromNull(input.getType),
          inclusiveMin = false, inclusiveMax = false)
      }
    }
  }

  if (to.scale <= from.scale) {
    if (!isFrom32Bit && isTo32Bit) {
      // check for overflow when 64bit => 32bit
      withResource(checkForOverflow) { checkedInput =>
        castCheckedDecimal(checkedInput)
      }
    } else {
      if (to.scale < 0 && !SQLConf.get.allowNegativeScaleOfDecimalEnabled) {
        throw new IllegalStateException(s"Negative scale is not allowed: ${to.scale}. " +
          s"You can use spark.sql.legacy.allowNegativeScaleOfDecimal=true " +
          s"to enable legacy mode to allow it.")
      }
      castCheckedDecimal(input)
    }
  } else {
    // to.scale > from.scale
    withResource(checkForOverflow) { checkedInput =>
      castCheckedDecimal(checkedInput)
    }
  }
}
}

/**
Expand Down Expand Up @@ -570,6 +395,91 @@ case class GpuCast(
}
}

/**
 * Verifies that no value of a column falls outside the given range, failing otherwise.
 *
 * The lower bound is checked first and the upper bound second. Each bound is a by-name
 * parameter: it is evaluated exactly once, right before its check, and the resulting Scalar
 * is closed by this method.
 *
 * @param values column whose values are range checked; it is not closed by this method
 * @param minValue by-name expression producing the Scalar for the lower bound
 * @param maxValue by-name expression producing the Scalar for the upper bound
 * @param inclusiveMin true when a value equal to the lower bound is still valid
 * @param inclusiveMax true when a value equal to the upper bound is still valid
 * @throws IllegalStateException if any value of the column lies outside the range
 */
private def assertValuesInRange(values: ColumnVector,
    minValue: => Scalar,
    maxValue: => Scalar,
    inclusiveMin: Boolean = true,
    inclusiveMax: Boolean = true): Unit = {

  // Materializes one bound, builds the mask of violating rows with `outOfRange`, and throws
  // when the any() reduction of that mask is true. The bound, the mask and the reduction
  // result are all closed before returning or propagating the exception.
  def failOnViolation(bound: => Scalar)(outOfRange: Scalar => ColumnVector): Unit = {
    withResource(bound) { limit =>
      withResource(outOfRange(limit)) { violations =>
        withResource(violations.any()) { anyViolation =>
          if (anyViolation.getBoolean) {
            throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE)
          }
        }
      }
    }
  }

  failOnViolation(minValue) { lower =>
    if (inclusiveMin) values.lessThan(lower) else values.lessOrEqualTo(lower)
  }

  failOnViolation(maxValue) { upper =>
    if (inclusiveMax) values.greaterThan(upper) else values.greaterOrEqualTo(upper)
  }
}

/**
 * Builds a column in which every value outside the given range is substituted with a
 * replacement scalar.
 *
 * All three scalars are by-name parameters: each is evaluated once and closed by this
 * method. The bounds are evaluated first (minimum, then maximum); the replacement is only
 * evaluated after the out-of-range mask has been computed.
 *
 * @param values column whose values are range checked; it is not closed by this method
 * @param minValue by-name expression producing the Scalar for the lower bound
 * @param maxValue by-name expression producing the Scalar for the upper bound
 * @param replaceValue by-name expression producing the Scalar written over outliers
 * @param inclusiveMin true when a value equal to the lower bound is still valid
 * @param inclusiveMax true when a value equal to the upper bound is still valid
 * @return the result of selecting the replacement where the mask is set and the original
 *         value elsewhere; the caller owns the returned column
 */
private def replaceOutOfRangeValues(values: ColumnVector,
    minValue: => Scalar,
    maxValue: => Scalar,
    replaceValue: => Scalar,
    inclusiveMin: Boolean = true,
    inclusiveMax: Boolean = true): ColumnVector = {

  // Mask of rows that fall below the valid range.
  def belowMinMask(lower: Scalar): ColumnVector =
    if (inclusiveMin) values.lessThan(lower) else values.lessOrEqualTo(lower)

  // Mask of rows that fall above the valid range.
  def aboveMaxMask(upper: Scalar): ColumnVector =
    if (inclusiveMax) values.greaterThan(upper) else values.greaterOrEqualTo(upper)

  withResource(minValue) { lower =>
    withResource(maxValue) { upper =>
      withResource(belowMinMask(lower)) { belowMin =>
        withResource(aboveMaxMask(upper)) { aboveMax =>
          withResource(aboveMax.or(belowMin)) { outOfRange =>
            withResource(replaceValue) { substitute =>
              outOfRange.ifElse(substitute, values)
            }
          }
        }
      }
    }
  }
}

private def castTimestampToString(input: GpuColumnVector): ColumnVector = {
withResource(input.getBase.castTo(DType.TIMESTAMP_MICROSECONDS)) { micros =>
Expand Down Expand Up @@ -1146,4 +1056,93 @@ case class GpuCast(
}
}
}

/**
 * Casts a decimal column from one decimal type to another.
 *
 * An overflow check runs first whenever the target scale is larger than the source scale,
 * or when a 64-bit decimal is narrowed to a 32-bit decimal. Values that fail the check
 * raise an error in ANSI mode and are replaced with null otherwise.
 *
 * @param input column to cast; it is not closed by this method
 * @param from decimal type of `input`
 * @param to decimal type to cast to
 * @param ansiMode when true, out-of-range values cause an exception instead of becoming null
 * @return the cast column, owned by the caller; when no conversion is required this is the
 *         value returned by `incRefCount()` on the input
 * @throws IllegalStateException in ANSI mode if a value is out of range for the target, or
 *                               (on the path without an overflow check) if the target scale
 *                               is negative and negative scales are not enabled in SQLConf
 */
private def castDecimalToDecimal(input: ColumnVector,
    from: DecimalType,
    to: DecimalType,
    ansiMode: Boolean): ColumnVector = {

  val isFrom32Bit = DecimalType.is32BitDecimalType(from)
  val isTo32Bit = DecimalType.is32BitDecimalType(to)
  val cudfDecimal = DecimalUtil.createCudfDecimal(to.precision, to.scale)

  // Converts an already range-checked column to the target cuDF decimal type.
  def castCheckedDecimal(checkedInput: ColumnVector): ColumnVector = {
    to.scale - from.scale match {
      case 0 =>
        if (isFrom32Bit == isTo32Bit) {
          // same scale and same width: nothing to convert
          checkedInput.incRefCount()
        } else {
          // the input is already checked, just cast it
          checkedInput.castTo(cudfDecimal)
        }
      case diff if diff > 0 =>
        checkedInput.castTo(cudfDecimal)
      case _ =>
        // the target has fewer fractional digits: round before casting
        withResource(checkedInput.round(to.scale, ai.rapids.cudf.RoundMode.HALF_UP)) {
          rounded => rounded.castTo(cudfDecimal)
        }
    }
  }

  // Returns a column (owned by the caller) whose values are safe to cast to the target.
  def checkForOverflow: ColumnVector = {
    // Check whether there exists overflow during promoting precision or not.
    // We do NOT use `Scalar.fromDecimal(-to.scale, math.pow(10, 18).toLong)` here, because
    // cuDF binaryOperation on decimal will rescale right input to fit the left one.
    // The rescaling may lead to overflow.
    // 10^prec is the exclusive bound on the unscaled input values before the cast.
    val prec = to.precision + from.scale - to.scale

    // When the bound needs more digits than a 32-bit input can hold, no input value can
    // reach it, so go ahead with the cast without further checking.
    // When we support 128 bit Decimals we should add a check for that
    // if (isFrom32Bit && prec > Decimal.MAX_INT_DIGITS ||
    // !isFrom32Bit && prec > Decimal.MAX_LONG_DIGITS)
    if (isFrom32Bit && prec > Decimal.MAX_INT_DIGITS) {
      input.incRefCount()
    } else {
      // Creates one bound on demand. The helpers below take their scalars by name and close
      // them, so a scalar must only come into existence when the helper evaluates it.
      // Creating both bounds eagerly leaked the max bound whenever the min check threw.
      def boundScalar(negate: Boolean): Scalar = {
        if (isFrom32Bit) {
          val absBound = math.pow(10, prec).toInt
          Scalar.fromDecimal(-from.scale, if (negate) -absBound else absBound)
        } else {
          val absBound = math.pow(10, prec).toLong
          Scalar.fromDecimal(-from.scale, if (negate) -absBound else absBound)
        }
      }

      if (ansiMode) {
        assertValuesInRange(input,
          minValue = boundScalar(negate = true),
          maxValue = boundScalar(negate = false),
          inclusiveMin = false, inclusiveMax = false)
        input.incRefCount()
      } else {
        replaceOutOfRangeValues(input,
          minValue = boundScalar(negate = true),
          maxValue = boundScalar(negate = false),
          replaceValue = Scalar.fromNull(input.getType),
          inclusiveMin = false, inclusiveMax = false)
      }
    }
  }

  if (to.scale <= from.scale) {
    if (!isFrom32Bit && isTo32Bit) {
      // check for overflow when 64bit => 32bit
      withResource(checkForOverflow) { checkedInput =>
        castCheckedDecimal(checkedInput)
      }
    } else {
      if (to.scale < 0 && !SQLConf.get.allowNegativeScaleOfDecimalEnabled) {
        throw new IllegalStateException(s"Negative scale is not allowed: ${to.scale}. " +
          s"You can use spark.sql.legacy.allowNegativeScaleOfDecimal=true " +
          s"to enable legacy mode to allow it.")
      }
      castCheckedDecimal(input)
    }
  } else {
    // to.scale > from.scale
    withResource(checkForOverflow) { checkedInput =>
      castCheckedDecimal(checkedInput)
    }
  }
}
}
Loading

0 comments on commit 8702bef

Please sign in to comment.