Skip to content

Commit

Permalink
Add support for split and getArrayIndex (NVIDIA#527)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
  • Loading branch information
revans2 authored Aug 10, 2020
1 parent 28ad4a3 commit b88039f
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Expm1"></a>spark.rapids.sql.expression.Expm1|`expm1`|Euler's number e raised to a power minus 1|true|None|
<a name="sql.expression.Floor"></a>spark.rapids.sql.expression.Floor|`floor`|Floor of a number|true|None|
<a name="sql.expression.FromUnixTime"></a>spark.rapids.sql.expression.FromUnixTime|`from_unixtime`|Get the string from a unix timestamp|true|None|
<a name="sql.expression.GetArrayItem"></a>spark.rapids.sql.expression.GetArrayItem| |Gets the field at `ordinal` in the Array|true|None|
<a name="sql.expression.GreaterThan"></a>spark.rapids.sql.expression.GreaterThan|`>`|> operator|true|None|
<a name="sql.expression.GreaterThanOrEqual"></a>spark.rapids.sql.expression.GreaterThanOrEqual|`>=`|>= operator|true|None|
<a name="sql.expression.Hour"></a>spark.rapids.sql.expression.Hour|`hour`|Returns the hour component of the string/timestamp|true|None|
Expand Down Expand Up @@ -186,6 +187,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.StringLocate"></a>spark.rapids.sql.expression.StringLocate|`position`, `locate`|Substring search operator|true|None|
<a name="sql.expression.StringRPad"></a>spark.rapids.sql.expression.StringRPad|`rpad`|Pad a string on the right|true|None|
<a name="sql.expression.StringReplace"></a>spark.rapids.sql.expression.StringReplace|`replace`|StringReplace operator|true|None|
<a name="sql.expression.StringSplit"></a>spark.rapids.sql.expression.StringSplit|`split`|Splits `str` around occurrences that match `regex`|true|None|
<a name="sql.expression.StringTrim"></a>spark.rapids.sql.expression.StringTrim|`trim`|StringTrim operator|true|None|
<a name="sql.expression.StringTrimLeft"></a>spark.rapids.sql.expression.StringTrimLeft|`ltrim`|StringTrimLeft operator|true|None|
<a name="sql.expression.StringTrimRight"></a>spark.rapids.sql.expression.StringTrimRight|`rtrim`|StringTrimRight operator|true|None|
Expand Down
14 changes: 14 additions & 0 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@
def mk_str_gen(pattern):
return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')

# Because of limitations in array support we need to combine these two together to make
# this work. This should be split up into separate tests once support is better.
def test_split_with_array_index():
data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
delim = '_'
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'split(a, "AB")[0]',
'split(a, "_")[1]',
'split(a, "_")[null]',
'split(a, "_")[3]',
'split(a, "_")[0]',
'split(a, "_")[-1]'))

@pytest.mark.parametrize('data_gen,delim', [(mk_str_gen('([ABC]{0,3}_?){0,7}'), '_'),
(mk_str_gen('([MNP_]{0,3}\\.?){0,5}'), '.'),
(mk_str_gen('([123]{0,3}\\^?){0,5}'), '^')], ids=idfn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.nvidia.spark.rapids;

import ai.rapids.cudf.ColumnViewAccess;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.HostColumnVector;
import ai.rapids.cudf.Scalar;
Expand Down Expand Up @@ -192,12 +193,22 @@ protected static final DataType getSparkType(DType type) {
case TIMESTAMP_DAYS:
return DataTypes.DateType;
case TIMESTAMP_MICROSECONDS:
return DataTypes.TimestampType; // TODO need to verify that the TimeUnits are correct
return DataTypes.TimestampType;
case STRING:
return DataTypes.StringType;
default:
throw new IllegalArgumentException(type + " is not supported by spark yet.");
}
}

protected static final DataType getSparkTypeFrom(ColumnViewAccess access) {
DType type = access.getDataType();
if (type == DType.LIST) {
try (ColumnViewAccess child = access.getChildColumnViewAccess(0)) {
return new ArrayType(getSparkTypeFrom(child), true);
}
} else {
return getSparkType(type);
}
}

Expand Down Expand Up @@ -300,7 +311,7 @@ public static final ColumnarBatch from(Table table, int startColIndex, int until
* but not both.
*/
public static final GpuColumnVector from(ai.rapids.cudf.ColumnVector cudfCv) {
return new GpuColumnVector(getSparkType(cudfCv.getType()), cudfCv);
return new GpuColumnVector(getSparkTypeFrom(cudfCv), cudfCv);
}

public static final GpuColumnVector from(Scalar scalar, int count) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public static ColumnarBatch from(ContiguousTable contigTable) {
try {
for (int i = 0; i < numColumns; ++i) {
ColumnVector v = table.getColumn(i);
DataType type = getSparkType(v.getType());
DataType type = getSparkTypeFrom(v);
columns[i] = new GpuColumnVectorFromBuffer(type, v.incRefCount(), buffer);
}
return new ColumnarBatch(columns, (int) rows);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,11 @@ object GpuOverrides {
"\\S", "\\v", "\\V", "\\w", "\\w", "\\p", "$", "\\b", "\\B", "\\A", "\\G", "\\Z", "\\z", "\\R",
"?", "|", "(", ")", "{", "}", "\\k", "\\Q", "\\E", ":", "!", "<=", ">")

def canRegexpBeTreatedLikeARegularString(strLit: UTF8String): Boolean = {
val s = strLit.toString
!regexList.exists(pattern => s.contains(pattern))
}

@scala.annotation.tailrec
def extractLit(exp: Expression): Option[Literal] = exp match {
case l: Literal => Some(l)
Expand Down Expand Up @@ -1328,6 +1333,12 @@ object GpuOverrides {
pad: Expression): GpuExpression =
GpuStringRPad(str, width, pad)
}),
expr[StringSplit](
"Splits `str` around occurrences that match `regex`",
(in, conf, p, r) => new GpuStringSplitMeta(in, conf, p, r)),
expr[GetArrayItem](
"Gets the field at `ordinal` in the Array",
(in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r)),
expr[StringLocate](
"Substring search operator",
(in, conf, p, r) => new TernaryExprMeta[StringLocate](in, conf, p, r) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.rapids

import ai.rapids.cudf.{ColumnVector, Scalar}
import com.nvidia.spark.rapids.{BinaryExprMeta, ConfKeysAndIncompat, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta}

import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExtractValue, GetArrayItem}
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, DataType, IntegralType}

class GpuGetArrayItemMeta(
expr: GetArrayItem,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: ConfKeysAndIncompat)
extends BinaryExprMeta[GetArrayItem](expr, conf, parent, rule) {
import GpuOverrides._

override def tagExprForGpu(): Unit = {
if (!isLit(expr.ordinal)) {
willNotWorkOnGpu("only literal ordinals are supported")
}
}
override def convertToGpu(
arr: Expression,
ordinal: Expression): GpuExpression =
GpuGetArrayItem(arr, ordinal)

def isSupported(t: DataType) = t match {
// For now we will only do one level of array type support
case a : ArrayType => isSupportedType(a.elementType)
case _ => isSupportedType(t)
}

override def areAllSupportedTypes(types: DataType*): Boolean = types.forall(isSupported)
}

/**
* Returns the field at `ordinal` in the Array `child`.
*
* We need to do type checking here as `ordinal` expression maybe unresolved.
*/
case class GpuGetArrayItem(child: Expression, ordinal: Expression)
extends GpuBinaryExpression with ExpectsInputTypes with ExtractValue {

// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)

override def toString: String = s"$child[$ordinal]"
override def sql: String = s"${child.sql}[${ordinal.sql}]"

override def left: Expression = child
override def right: Expression = ordinal
// Eventually we need something more full featured like
// GetArrayItemUtil.computeNullabilityFromArray
override def nullable: Boolean = true
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): GpuColumnVector =
throw new IllegalStateException("This is not supported yet")

override def doColumnar(lhs: Scalar, rhs: GpuColumnVector): GpuColumnVector =
throw new IllegalStateException("This is not supported yet")

override def doColumnar(lhs: GpuColumnVector, ordinal: Scalar): GpuColumnVector = {
// Need to handle negative indexes...
if (ordinal.isValid && ordinal.getInt >= 0) {
GpuColumnVector.from(lhs.getBase.extractListElement(ordinal.getInt))
} else {
withResource(Scalar.fromNull(GpuColumnVector.getRapidsType(dataType))) { nullScalar =>
GpuColumnVector.from(ColumnVector.fromScalar(nullScalar, lhs.getRowCount.toInt))
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import ai.rapids.cudf.{ColumnVector, DType, PadSide, Scalar, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.RapidsPluginImplicits._

import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, NullIntolerant, Predicate, SubstringIndex}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, NullIntolerant, Predicate, StringSplit, SubstringIndex}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -768,4 +768,96 @@ case class GpuStringRPad(str: Expression, len: Expression, pad: Expression)
def this(str: Expression, len: Expression) = {
this(str, len, GpuLiteral(" ", StringType))
}
}

class GpuStringSplitMeta(
expr: StringSplit,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: ConfKeysAndIncompat)
extends TernaryExprMeta[StringSplit](expr, conf, parent, rule) {
import GpuOverrides._

override def tagExprForGpu(): Unit = {
val regexp = extractLit(expr.regex)
if (regexp.isEmpty) {
willNotWorkOnGpu("only literal regexp values are supported")
} else {
val str = regexp.get.value.asInstanceOf[UTF8String]
if (str != null) {
if (!canRegexpBeTreatedLikeARegularString(str)) {
willNotWorkOnGpu("regular expressions are not supported yet")
}
if (str.numChars() == 0) {
willNotWorkOnGpu("An empty regex is not supported yet")
}
} else {
willNotWorkOnGpu("null regex is not supported yet")
}
}
if (!isLit(expr.limit)) {
willNotWorkOnGpu("only literal limit is supported")
}
}
override def convertToGpu(
str: Expression,
regexp: Expression,
limit: Expression): GpuExpression =
GpuStringSplit(str, regexp, limit)

// For now we support all of the possible input and output types for this operator
override def areAllSupportedTypes(types: DataType*): Boolean = true
}

case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression)
extends GpuTernaryExpression with ImplicitCastInputTypes {

override def dataType: DataType = ArrayType(StringType)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
override def children: Seq[Expression] = str :: regex :: limit :: Nil

def this(exp: Expression, regex: Expression) = this(exp, regex, GpuLiteral(-1, IntegerType))

override def prettyName: String = "split"

override def doColumnar(str: GpuColumnVector, regex: Scalar, limit: Scalar): GpuColumnVector = {
val intLimit = limit.getInt
GpuColumnVector.from(str.getBase.stringSplitRecord(regex, intLimit))
}

override def doColumnar(
str: GpuColumnVector,
regex: GpuColumnVector,
limit: GpuColumnVector): GpuColumnVector =
throw new IllegalStateException("This is not supported yet")

override def doColumnar(
str: Scalar,
regex: GpuColumnVector,
limit: GpuColumnVector): GpuColumnVector =
throw new IllegalStateException("This is not supported yet")

override def doColumnar(
str: Scalar,
regex: Scalar,
limit: GpuColumnVector): GpuColumnVector =
throw new IllegalStateException("This is not supported yet")

override def doColumnar(
str: Scalar,
regex: GpuColumnVector,
limit: Scalar): GpuColumnVector =
throw new IllegalStateException("This is not supported yet")

override def doColumnar(
str: GpuColumnVector,
regex: Scalar,
limit: GpuColumnVector): GpuColumnVector =
throw new IllegalStateException("This is not supported yet")

override def doColumnar(
str: GpuColumnVector,
regex: GpuColumnVector,
limit: Scalar): GpuColumnVector =
throw new IllegalStateException("This is not supported yet")
}

0 comments on commit b88039f

Please sign in to comment.