Add support for split and getArrayIndex #527

Merged
merged 3 commits on Aug 10, 2020
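This PR wires Spark's `split` function and literal array indexing through to cudf so both can run on the GPU. A minimal spark-shell sketch of what it enables (the DataFrame and column names are hypothetical, and it assumes a SparkSession with the RAPIDS plugin enabled):

```scala
import spark.implicits._

// `df` and column `a` are made-up names for illustration only.
val df = Seq("A_B_C", "X_Y", "solo").toDF("a")
df.selectExpr(
  "split(a, '_')",    // now planned as GpuStringSplit
  "split(a, '_')[1]"  // now planned as GpuGetArrayItem
).show()
```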
2 changes: 2 additions & 0 deletions docs/configs.md
@@ -128,6 +128,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Expm1"></a>spark.rapids.sql.expression.Expm1|`expm1`|Euler's number e raised to a power minus 1|true|None|
<a name="sql.expression.Floor"></a>spark.rapids.sql.expression.Floor|`floor`|Floor of a number|true|None|
<a name="sql.expression.FromUnixTime"></a>spark.rapids.sql.expression.FromUnixTime|`from_unixtime`|Get the string from a unix timestamp|true|None|
<a name="sql.expression.GetArrayItem"></a>spark.rapids.sql.expression.GetArrayItem| |Gets the field at `ordinal` in the Array|true|None|
<a name="sql.expression.GreaterThan"></a>spark.rapids.sql.expression.GreaterThan|`>`|> operator|true|None|
<a name="sql.expression.GreaterThanOrEqual"></a>spark.rapids.sql.expression.GreaterThanOrEqual|`>=`|>= operator|true|None|
<a name="sql.expression.Hour"></a>spark.rapids.sql.expression.Hour|`hour`|Returns the hour component of the string/timestamp|true|None|
@@ -186,6 +187,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.StringLocate"></a>spark.rapids.sql.expression.StringLocate|`position`, `locate`|Substring search operator|true|None|
<a name="sql.expression.StringRPad"></a>spark.rapids.sql.expression.StringRPad|`rpad`|Pad a string on the right|true|None|
<a name="sql.expression.StringReplace"></a>spark.rapids.sql.expression.StringReplace|`replace`|StringReplace operator|true|None|
<a name="sql.expression.StringSplit"></a>spark.rapids.sql.expression.StringSplit|`split`|Splits `str` around occurrences that match `regex`|true|None|
<a name="sql.expression.StringTrim"></a>spark.rapids.sql.expression.StringTrim|`trim`|StringTrim operator|true|None|
<a name="sql.expression.StringTrimLeft"></a>spark.rapids.sql.expression.StringTrimLeft|`ltrim`|StringTrimLeft operator|true|None|
<a name="sql.expression.StringTrimRight"></a>spark.rapids.sql.expression.StringTrimRight|`rtrim`|StringTrimRight operator|true|None|
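Both new rows follow the table's existing pattern: the expressions are on by default and each can be disabled through its config key. For example, a sketch using the standard runtime conf API:

```scala
// If a workload hits an unsupported case, either expression can be turned
// off at runtime with the keys documented above.
spark.conf.set("spark.rapids.sql.expression.StringSplit", "false")
spark.conf.set("spark.rapids.sql.expression.GetArrayItem", "false")
```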
14 changes: 14 additions & 0 deletions integration_tests/src/main/python/string_test.py
@@ -23,6 +23,20 @@
def mk_str_gen(pattern):
    return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')

# Because of limitations in array support we need to combine these two together to make
# this work. This should be split up into separate tests once support is better.
def test_split_with_array_index():
    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'split(a, "AB")[0]',
            'split(a, "_")[1]',
            'split(a, "_")[null]',
            'split(a, "_")[3]',
            'split(a, "_")[0]',
            'split(a, "_")[-1]'))

@pytest.mark.parametrize('data_gen,delim', [(mk_str_gen('([ABC]{0,3}_?){0,7}'), '_'),
                                            (mk_str_gen('([MNP_]{0,3}\\.?){0,5}'), '.'),
                                            (mk_str_gen('([123]{0,3}\\^?){0,5}'), '^')], ids=idfn)
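The new test leans on Spark's non-ANSI `GetArrayItem` semantics: an out-of-range, negative, or null ordinal yields null rather than an error, which is what the GPU implementation later in this PR mirrors. A Scala sketch of the same expectations (input value assumed for illustration):

```scala
import spark.implicits._

Seq("A_B").toDF("a").selectExpr(
  "split(a, '_')[0]",   // "A"
  "split(a, '_')[5]",   // null: past the end of the array
  "split(a, '_')[-1]",  // null: negative ordinal
  "split(a, '_')[null]" // null: null ordinal
).show()
```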
@@ -16,6 +16,7 @@

package com.nvidia.spark.rapids;

import ai.rapids.cudf.ColumnViewAccess;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.HostColumnVector;
import ai.rapids.cudf.Scalar;
@@ -192,12 +193,22 @@ protected static final DataType getSparkType(DType type) {
      case TIMESTAMP_DAYS:
        return DataTypes.DateType;
      case TIMESTAMP_MICROSECONDS:
-       return DataTypes.TimestampType; // TODO need to verify that the TimeUnits are correct
+       return DataTypes.TimestampType;
      case STRING:
        return DataTypes.StringType;
      default:
        throw new IllegalArgumentException(type + " is not supported by spark yet.");
    }
  }

  protected static final DataType getSparkTypeFrom(ColumnViewAccess access) {
    DType type = access.getDataType();
    if (type == DType.LIST) {
      try (ColumnViewAccess child = access.getChildColumnViewAccess(0)) {
        return new ArrayType(getSparkTypeFrom(child), true);
      }
    } else {
      return getSparkType(type);
    }
  }

@@ -300,7 +311,7 @@ public static final ColumnarBatch from(Table table, int startColIndex, int until
   * but not both.
   */
  public static final GpuColumnVector from(ai.rapids.cudf.ColumnVector cudfCv) {
-   return new GpuColumnVector(getSparkType(cudfCv.getType()), cudfCv);
+   return new GpuColumnVector(getSparkTypeFrom(cudfCv), cudfCv);
  }

  public static final GpuColumnVector from(Scalar scalar, int count) {
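`getSparkTypeFrom` recurses through LIST children, so nested list columns map onto nested `ArrayType`s. A rough Scala rendering of the same logic (illustrative only; `withResource` is the plugin's auto-closing helper):

```scala
// A cudf LIST<LIST<STRING>> column becomes ArrayType(ArrayType(StringType)).
def sparkTypeOf(view: ColumnViewAccess): DataType =
  if (view.getDataType == DType.LIST) {
    withResource(view.getChildColumnViewAccess(0)) { child =>
      ArrayType(sparkTypeOf(child), containsNull = true)
    }
  } else {
    getSparkType(view.getDataType)
  }
```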
@@ -49,7 +49,7 @@ public static ColumnarBatch from(ContiguousTable contigTable) {
    try {
      for (int i = 0; i < numColumns; ++i) {
        ColumnVector v = table.getColumn(i);
-       DataType type = getSparkType(v.getType());
+       DataType type = getSparkTypeFrom(v);
        columns[i] = new GpuColumnVectorFromBuffer(type, v.incRefCount(), buffer);
      }
      return new ColumnarBatch(columns, (int) rows);
@@ -336,6 +336,11 @@ object GpuOverrides {
"\\S", "\\v", "\\V", "\\w", "\\w", "\\p", "$", "\\b", "\\B", "\\A", "\\G", "\\Z", "\\z", "\\R",
"?", "|", "(", ")", "{", "}", "\\k", "\\Q", "\\E", ":", "!", "<=", ">")

def canRegexpBeTreatedLikeARegularString(strLit: UTF8String): Boolean = {
val s = strLit.toString
!regexList.exists(pattern => s.contains(pattern))
}

  @scala.annotation.tailrec
  def extractLit(exp: Expression): Option[Literal] = exp match {
    case l: Literal => Some(l)
@@ -1328,6 +1333,12 @@ object GpuOverrides {
            pad: Expression): GpuExpression =
          GpuStringRPad(str, width, pad)
      }),
    expr[StringSplit](
      "Splits `str` around occurrences that match `regex`",
      (in, conf, p, r) => new GpuStringSplitMeta(in, conf, p, r)),
    expr[GetArrayItem](
      "Gets the field at `ordinal` in the Array",
      (in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r)),
    expr[StringLocate](
      "Substring search operator",
      (in, conf, p, r) => new TernaryExprMeta[StringLocate](in, conf, p, r) {
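`canRegexpBeTreatedLikeARegularString` is a conservative check: a literal pattern qualifies only when it contains none of the fragments in `regexList`, so cudf's plain-string split can stand in for Spark's regex split. For instance:

```scala
import org.apache.spark.unsafe.types.UTF8String

canRegexpBeTreatedLikeARegularString(UTF8String.fromString("AB"))  // true: no metacharacters
canRegexpBeTreatedLikeARegularString(UTF8String.fromString("A|B")) // false: '|' is in regexList
```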
@@ -0,0 +1,89 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.rapids

import ai.rapids.cudf.{ColumnVector, Scalar}
import com.nvidia.spark.rapids.{BinaryExprMeta, ConfKeysAndIncompat, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta}

import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExtractValue, GetArrayItem}
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, DataType, IntegralType}

class GpuGetArrayItemMeta(
    expr: GetArrayItem,
    conf: RapidsConf,
    parent: Option[RapidsMeta[_, _, _]],
    rule: ConfKeysAndIncompat)
  extends BinaryExprMeta[GetArrayItem](expr, conf, parent, rule) {
  import GpuOverrides._

  override def tagExprForGpu(): Unit = {
    if (!isLit(expr.ordinal)) {
      willNotWorkOnGpu("only literal ordinals are supported")
    }
  }

  override def convertToGpu(
      arr: Expression,
      ordinal: Expression): GpuExpression =
    GpuGetArrayItem(arr, ordinal)

  def isSupported(t: DataType): Boolean = t match {
    // For now we will only do one level of array type support
    case a: ArrayType => isSupportedType(a.elementType)
    case _ => isSupportedType(t)
  }

  override def areAllSupportedTypes(types: DataType*): Boolean = types.forall(isSupported)
}

/**
 * Returns the field at `ordinal` in the Array `child`.
 *
 * We need to do type checking here as the `ordinal` expression may be unresolved.
 */
case class GpuGetArrayItem(child: Expression, ordinal: Expression)
    extends GpuBinaryExpression with ExpectsInputTypes with ExtractValue {

  // We have done type checking for child in `ExtractValue`, so we only need to check the `ordinal`.
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)

  override def toString: String = s"$child[$ordinal]"
  override def sql: String = s"${child.sql}[${ordinal.sql}]"

  override def left: Expression = child
  override def right: Expression = ordinal
  // Eventually we need something more full featured like
  // GetArrayItemUtil.computeNullabilityFromArray
  override def nullable: Boolean = true
  override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

  override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): GpuColumnVector =
    throw new IllegalStateException("This is not supported yet")

  override def doColumnar(lhs: Scalar, rhs: GpuColumnVector): GpuColumnVector =
    throw new IllegalStateException("This is not supported yet")

  override def doColumnar(lhs: GpuColumnVector, ordinal: Scalar): GpuColumnVector = {
    // Null and negative ordinals always produce null, matching Spark's CPU behavior.
    if (ordinal.isValid && ordinal.getInt >= 0) {
      GpuColumnVector.from(lhs.getBase.extractListElement(ordinal.getInt))
    } else {
      withResource(Scalar.fromNull(GpuColumnVector.getRapidsType(dataType))) { nullScalar =>
        GpuColumnVector.from(ColumnVector.fromScalar(nullScalar, lhs.getRowCount.toInt))
      }
    }
  }
}
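Because the meta rejects non-literal ordinals, only constant indexes stay on the GPU; anything else is tagged and falls back to the CPU plan. A sketch, assuming a DataFrame `df` with a string column `a` and an integer column `b`:

```scala
// Stays on the GPU: the ordinal is a literal.
df.selectExpr("split(a, '_')[2]")

// Tagged willNotWorkOnGpu ("only literal ordinals are supported"),
// so this expression runs on the CPU instead.
df.selectExpr("split(a, '_')[b]")
```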
@@ -22,7 +22,7 @@ import ai.rapids.cudf.{ColumnVector, DType, PadSide, Scalar, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.RapidsPluginImplicits._

-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, NullIntolerant, Predicate, SubstringIndex}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, NullIntolerant, Predicate, StringSplit, SubstringIndex}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String
@@ -768,4 +768,96 @@ case class GpuStringRPad(str: Expression, len: Expression, pad: Expression)
  def this(str: Expression, len: Expression) = {
    this(str, len, GpuLiteral(" ", StringType))
  }
}

class GpuStringSplitMeta(
    expr: StringSplit,
    conf: RapidsConf,
    parent: Option[RapidsMeta[_, _, _]],
    rule: ConfKeysAndIncompat)
  extends TernaryExprMeta[StringSplit](expr, conf, parent, rule) {
  import GpuOverrides._

  override def tagExprForGpu(): Unit = {
    val regexp = extractLit(expr.regex)
    if (regexp.isEmpty) {
      willNotWorkOnGpu("only literal regexp values are supported")
    } else {
      val str = regexp.get.value.asInstanceOf[UTF8String]
      if (str != null) {
        if (!canRegexpBeTreatedLikeARegularString(str)) {
          willNotWorkOnGpu("regular expressions are not supported yet")
        }
        if (str.numChars() == 0) {
          willNotWorkOnGpu("An empty regex is not supported yet")
        }
      } else {
        willNotWorkOnGpu("null regex is not supported yet")
      }
    }
    if (!isLit(expr.limit)) {
      willNotWorkOnGpu("only literal limit is supported")
    }
  }

  override def convertToGpu(
      str: Expression,
      regexp: Expression,
      limit: Expression): GpuExpression =
    GpuStringSplit(str, regexp, limit)

  // For now we support all of the possible input and output types for this operator
  override def areAllSupportedTypes(types: DataType*): Boolean = true
}

case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression)
    extends GpuTernaryExpression with ImplicitCastInputTypes {

  override def dataType: DataType = ArrayType(StringType)
  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
  override def children: Seq[Expression] = str :: regex :: limit :: Nil

  def this(exp: Expression, regex: Expression) = this(exp, regex, GpuLiteral(-1, IntegerType))

  override def prettyName: String = "split"

  override def doColumnar(str: GpuColumnVector, regex: Scalar, limit: Scalar): GpuColumnVector = {
    val intLimit = limit.getInt
    GpuColumnVector.from(str.getBase.stringSplitRecord(regex, intLimit))
  }

  override def doColumnar(
      str: GpuColumnVector,
      regex: GpuColumnVector,
      limit: GpuColumnVector): GpuColumnVector =
    throw new IllegalStateException("This is not supported yet")

  override def doColumnar(
      str: Scalar,
      regex: GpuColumnVector,
      limit: GpuColumnVector): GpuColumnVector =
    throw new IllegalStateException("This is not supported yet")

  override def doColumnar(
      str: Scalar,
      regex: Scalar,
      limit: GpuColumnVector): GpuColumnVector =
    throw new IllegalStateException("This is not supported yet")

  override def doColumnar(
      str: Scalar,
      regex: GpuColumnVector,
      limit: Scalar): GpuColumnVector =
    throw new IllegalStateException("This is not supported yet")

  override def doColumnar(
      str: GpuColumnVector,
      regex: Scalar,
      limit: GpuColumnVector): GpuColumnVector =
    throw new IllegalStateException("This is not supported yet")

  override def doColumnar(
      str: GpuColumnVector,
      regex: GpuColumnVector,
      limit: Scalar): GpuColumnVector =
    throw new IllegalStateException("This is not supported yet")
}
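`GpuStringSplit` only implements the column/scalar/scalar overload, which is the only shape the planner can produce once the meta has required literal `regex` and `limit` values. The limit keeps Spark's semantics; a sketch of the expected results (assuming Spark 3.0's three-argument `split`):

```scala
import spark.implicits._

Seq("A_B_C").toDF("a").selectExpr(
  "split(a, '_')",    // ["A","B","C"]: the two-argument form uses limit -1
  "split(a, '_', 2)"  // ["A","B_C"]: at most two elements, remainder kept in the last
).show(false)
```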