diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index f996df38307..f3ea87e9223 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -34,7 +34,6 @@ class CastExprMeta[INPUT <: CastBase]( rule: DataFromReplacementRule) extends UnaryExprMeta[INPUT](cast, conf, parent, rule) { - private val castExpr = if (ansiEnabled) "ansi_cast" else "cast" val fromType = cast.child.dataType val toType = cast.dataType @@ -367,7 +366,7 @@ case class GpuCast( } val longStrings = withResource(trimmed.matchesRe(regex)) { regexMatches => if (ansiMode) { - withResource(regexMatches.all(DType.BOOL8)) { allRegexMatches => + withResource(regexMatches.all()) { allRegexMatches => if (!allRegexMatches.getBoolean) { throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE) } @@ -546,7 +545,7 @@ case class GpuCast( withResource(input.contains(boolStrings)) { validBools => // in ansi mode, fail if any values are not valid bool strings if (ansiEnabled) { - withResource(validBools.all(DType.BOOL8)) { isAllBool => + withResource(validBools.all()) { isAllBool => if (!isAllBool.getBoolean) { throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE) } @@ -964,7 +963,7 @@ case class GpuCast( // replace values less than minValue with null val gtEqMinOrNull = withResource(values.greaterOrEqualTo(minValue)) { isGtEqMin => if (ansiMode) { - withResource(isGtEqMin.all(DType.BOOL8)) { all => + withResource(isGtEqMin.all()) { all => if (!all.getBoolean) { throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE) } @@ -977,7 +976,7 @@ case class GpuCast( val ltEqMaxOrNull = withResource(gtEqMinOrNull) { gtEqMinOrNull => withResource(gtEqMinOrNull.lessOrEqualTo(maxValue)) { isLtEqMax => if (ansiMode) { - withResource(isLtEqMax.all(DType.BOOL8)) { all => + withResource(isLtEqMax.all()) { all => if (!all.getBoolean) { throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala index e0df29805a1..aeda82eff93 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala @@ -56,7 +56,7 @@ case class GpuUnscaledValue(child: Expression) extends GpuUnaryExpression { override def toString: String = s"UnscaledValue($child)" override protected def doColumnar(input: GpuColumnVector): ColumnVector = { - withResource(input.getBase.logicalCastTo(DType.INT64)) { view => + withResource(input.getBase.bitCastTo(DType.INT64)) { view => view.copyToColumnVector() } } @@ -85,13 +85,13 @@ case class GpuMakeDecimal( } withResource(overflowed) { overflowed => withResource(Scalar.fromNull(outputType)) { nullVal => - withResource(base.logicalCastTo(outputType)) { view => + withResource(base.bitCastTo(outputType)) { view => overflowed.ifElse(nullVal, view) } } } } else { - withResource(base.logicalCastTo(outputType)) { view => + withResource(base.bitCastTo(outputType)) { view => view.copyToColumnVector() } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala index 0fcd3b4bfce..6f6e55322ee 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf import ai.rapids.cudf.{Aggregation, AggregationOnColumn, ColumnVector, DType} +import ai.rapids.cudf.Aggregation.NullPolicy import com.nvidia.spark.rapids._ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -180,13 +181,11 @@ abstract case class CudfAggregate(ref: Expression) extends GpuUnevaluable { } class CudfCount(ref: Expression) extends CudfAggregate(ref) { - // includeNulls set to false in count aggregate to exclude nulls while calculating count(column) - val includeNulls = false override val updateReductionAggregate: cudf.ColumnVector => cudf.Scalar = (col: cudf.ColumnVector) => cudf.Scalar.fromLong(col.getRowCount - col.getNullCount) override val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar = (col: cudf.ColumnVector) => col.sum - override lazy val updateAggregate: Aggregation = Aggregation.count(includeNulls) + override lazy val updateAggregate: Aggregation = Aggregation.count(NullPolicy.EXCLUDE) override lazy val mergeAggregate: Aggregation = Aggregation.sum() override def toString(): String = "CudfCount" } @@ -241,7 +240,7 @@ class CudfMin(ref: Expression) extends CudfAggregate(ref) { } abstract class CudfFirstLastBase(ref: Expression) extends CudfAggregate(ref) { - val includeNulls: Boolean + val includeNulls: NullPolicy val offset: Int override val updateReductionAggregate: cudf.ColumnVector => cudf.Scalar = @@ -253,22 +252,22 @@ abstract class CudfFirstLastBase(ref: Expression) extends CudfAggregate(ref) { } class CudfFirstIncludeNulls(ref: Expression) extends CudfFirstLastBase(ref) { - override val includeNulls: Boolean = true + override val includeNulls: NullPolicy = NullPolicy.INCLUDE override val offset: Int = 0 } class CudfFirstExcludeNulls(ref: Expression) extends CudfFirstLastBase(ref) { - override val includeNulls: Boolean = false + override val includeNulls: NullPolicy = NullPolicy.EXCLUDE override val offset: Int = 0 } class CudfLastIncludeNulls(ref: Expression) extends CudfFirstLastBase(ref) { - override val includeNulls: Boolean = true + override val includeNulls: NullPolicy = NullPolicy.INCLUDE override val offset: Int = -1 } class CudfLastExcludeNulls(ref: Expression) extends CudfFirstLastBase(ref) { - override val includeNulls: Boolean = false + override val includeNulls: NullPolicy = NullPolicy.EXCLUDE override val offset: Int = -1 } @@ -399,7 +398,7 @@ case class GpuCount(children: Seq[Expression]) extends GpuDeclarativeAggregate // we could support it by doing an `Aggregation.nunique(false)` override lazy val windowInputProjection: Seq[Expression] = inputProjection override def windowAggregation(inputs: Seq[(ColumnVector, Int)]): AggregationOnColumn = - Aggregation.count(false).onColumn(inputs.head._2) + Aggregation.count(NullPolicy.EXCLUDE).onColumn(inputs.head._2) } case class GpuAverage(child: Expression) extends GpuDeclarativeAggregate diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala index 53812b31095..4e32dc219e5 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf import ai.rapids.cudf.{Aggregation, OrderByArg} +import ai.rapids.cudf.Aggregation.NullPolicy import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.GpuMetric._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -122,7 +123,7 @@ class GroupingIterator( withResource(GpuColumnVector.from(projected)) { table => table .groupBy(partitionIndices:_*) - .aggregate(Aggregation.count(true).onColumn(0)) + .aggregate(Aggregation.count(NullPolicy.INCLUDE).onColumn(0)) } } val orderedTable = withResource(cntTable) { table =>