Skip to content

Commit

Permalink
Merge pull request #2689 from NVIDIA/branch-21.06
Browse files Browse the repository at this point in the history
[auto-merge] branch-21.06 to branch-21.08 [skip ci] [bot]
  • Loading branch information
nvauto authored Jun 10, 2021
2 parents b4dc530 + 2f977dc commit 073b5a4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
32 changes: 32 additions & 0 deletions integration_tests/src/main/python/cast_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error
from data_gen import *
from functools import reduce
from spark_session import is_before_spark_311
from marks import allow_non_gpu
from pyspark.sql.types import *
from pyspark.sql.functions import array_contains, col, first, isnan, lit, element_at

def test_cast_empty_string_to_int():
    """Compare CPU and GPU results when casting strings to the integral types.

    The input DataFrame is built with ``unary_op_df`` from a ``StringGen`` whose
    pattern is the empty string, and column ``a`` is cast to BYTE, SHORT,
    INTEGER and LONG in a single ``selectExpr``.
    """
    # NOTE(review): presumably StringGen(pattern="") yields only empty (or null)
    # strings, so the underlying character data may be empty -- confirm against
    # data_gen.
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, StringGen(pattern="")).selectExpr(
            'CAST(a as BYTE)',
            'CAST(a as SHORT)',
            'CAST(a as INTEGER)',
            'CAST(a as LONG)'))

13 changes: 9 additions & 4 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -670,12 +670,17 @@ case class GpuCast(
// To avoid doing the expensive regex all the time, we will first check to see if we need
// to do it. The only time we do need to do it is when we have a '.' in any of the strings.
val data = input.getData
val hasDot = withResource(
ColumnView.fromDeviceBuffer(data, 0, DType.INT8, data.getLength.toInt)) { childData =>
withResource(GpuScalar.from('.'.toByte, ByteType)) { dot =>
childData.contains(dot)
val hasDot = if (data != null) {
withResource(
ColumnView.fromDeviceBuffer(data, 0, DType.INT8, data.getLength.toInt)) { childData =>
withResource(GpuScalar.from('.'.toByte, ByteType)) { dot =>
childData.contains(dot)
}
}
} else {
false
}

if (hasDot) {
withResource(input.extractRe("^([+\\-]?[0-9]+)(?:\\.[0-9]*)?$")) { table =>
table.getColumn(0).incRefCount()
Expand Down

0 comments on commit 073b5a4

Please sign in to comment.