Decimal support for add and subtract (NVIDIA#1561)
Signed-off-by: Ryan Lee <ryanlee@nvidia.com>
Co-authored-by: Robert (Bobby) Evans <bobby@apache.org>
rwlee and revans2 authored Jan 21, 2021
1 parent a856c11 commit 39d8adc
Showing 5 changed files with 256 additions and 19 deletions.
2 changes: 2 additions & 0 deletions docs/configs.md
@@ -124,6 +124,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Cast"></a>spark.rapids.sql.expression.Cast|`timestamp`, `tinyint`, `binary`, `float`, `smallint`, `string`, `decimal`, `double`, `boolean`, `cast`, `date`, `int`, `bigint`|Convert a column of one type of data into another type|true|None|
<a name="sql.expression.Cbrt"></a>spark.rapids.sql.expression.Cbrt|`cbrt`|Cube root|true|None|
<a name="sql.expression.Ceil"></a>spark.rapids.sql.expression.Ceil|`ceiling`, `ceil`|Ceiling of a number|true|None|
<a name="sql.expression.CheckOverflow"></a>spark.rapids.sql.expression.CheckOverflow| |CheckOverflow after arithmetic operations between DecimalType data|true|None|
<a name="sql.expression.Coalesce"></a>spark.rapids.sql.expression.Coalesce|`coalesce`|Returns the first non-null argument if exists. Otherwise, null|true|None|
<a name="sql.expression.Concat"></a>spark.rapids.sql.expression.Concat|`concat`|String concatenate NO separator|true|None|
<a name="sql.expression.Contains"></a>spark.rapids.sql.expression.Contains| |Contains|true|None|
@@ -193,6 +194,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Or"></a>spark.rapids.sql.expression.Or|`or`|Logical OR|true|None|
<a name="sql.expression.Pmod"></a>spark.rapids.sql.expression.Pmod|`pmod`|Pmod|true|None|
<a name="sql.expression.Pow"></a>spark.rapids.sql.expression.Pow|`pow`, `power`|lhs ^ rhs|true|None|
<a name="sql.expression.PromotePrecision"></a>spark.rapids.sql.expression.PromotePrecision| |PromotePrecision before arithmetic operations between DecimalType data|true|None|
<a name="sql.expression.PythonUDF"></a>spark.rapids.sql.expression.PythonUDF| |UDF run in an external python process. Does not actually run on the GPU, but the transfer of data to/from it can be accelerated.|true|None|
<a name="sql.expression.Quarter"></a>spark.rapids.sql.expression.Quarter|`quarter`|Returns the quarter of the year for date, in the range 1 to 4|true|None|
<a name="sql.expression.Rand"></a>spark.rapids.sql.expression.Rand|`random`, `rand`|Generate a random column with i.i.d. uniformly distributed values in [0, 1)|true|None|
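The two rows added above register CheckOverflow and PromotePrecision, the wrapper expressions Spark's analyzer places around every decimal add or subtract: PromotePrecision casts both sides to a common decimal type, and CheckOverflow nulls out (or, in ANSI mode, errors on) results that exceed the declared precision. Both must be GPU-enabled or the wrapped Add/Subtract falls back to the CPU. A minimal sketch to see the wrapping (assumes a plain local PySpark install; no GPU or plugin needed, and the exact plan text varies by Spark version):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    # DECIMAL(5,2) + DECIMAL(5,2) yields DECIMAL(6,2); the analyzed plan shows
    # the sum as CheckOverflow(promote_precision(...) + promote_precision(...)),
    # which is why both expressions need entries in this table.
    df = spark.sql(
        "SELECT CAST(1.23 AS DECIMAL(5,2)) + CAST(4.56 AS DECIMAL(5,2)) AS s")
    df.explain(extended=True)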
192 changes: 186 additions & 6 deletions docs/supported_ops.md
@@ -1068,7 +1068,7 @@ Accelerator support is described below.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td>S*</td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
@@ -1089,7 +1089,7 @@ Accelerator support is described below.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td>S*</td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
@@ -1110,7 +1110,7 @@ Accelerator support is described below.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td>S*</td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
@@ -2834,6 +2834,96 @@ Accelerator support is described below.
<td> </td>
</tr>
<tr>
<td rowSpan="4">CheckOverflow</td>
<td rowSpan="4"> </td>
<td rowSpan="4">CheckOverflow after arithmetic operations between DecimalType data</td>
<td rowSpan="4">None</td>
<td rowSpan="2">project</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S*</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S*</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="2">lambda</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="4">Coalesce</td>
<td rowSpan="4">`coalesce`</td>
<td rowSpan="4">Returns the first non-null argument if exists. Otherwise, null</td>
@@ -9859,6 +9949,96 @@ Accelerator support is described below.
<td> </td>
</tr>
<tr>
<td rowSpan="4">PromotePrecision</td>
<td rowSpan="4"> </td>
<td rowSpan="4">PromotePrecision before arithmetic operations between DecimalType data</td>
<td rowSpan="4">None</td>
<td rowSpan="2">project</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S*</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S*</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="2">lambda</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="8">PythonUDF</td>
<td rowSpan="8"> </td>
<td rowSpan="8">UDF run in an external python process. Does not actually run on the GPU, but the transfer of data to/from it can be accelerated.</td>
@@ -13540,7 +13720,7 @@ Accelerator support is described below.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td>S*</td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
@@ -13561,7 +13741,7 @@ Accelerator support is described below.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td>S*</td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
@@ -13582,7 +13762,7 @@ Accelerator support is described below.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td>S*</td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
13 changes: 9 additions & 4 deletions integration_tests/src/main/python/arithmetic_ops_test.py
@@ -21,7 +21,10 @@
from spark_session import with_spark_session, is_before_spark_310
import pyspark.sql.functions as f

@pytest.mark.parametrize('data_gen', numeric_gens, ids=idfn)
decimal_gens_not_max_prec = [decimal_gen_neg_scale, decimal_gen_scale_precision,
decimal_gen_same_scale_precision, decimal_gen_64bit]

@pytest.mark.parametrize('data_gen', numeric_gens + decimal_gens_not_max_prec, ids=idfn)
def test_addition(data_gen):
data_type = data_gen.data_type
assert_gpu_and_cpu_are_equal_collect(
@@ -30,9 +33,10 @@ def test_addition(data_gen):
f.lit(-12).cast(data_type) + f.col('b'),
f.lit(None).cast(data_type) + f.col('a'),
f.col('b') + f.lit(None).cast(data_type),
f.col('a') + f.col('b')))
f.col('a') + f.col('b')),
conf=allow_negative_scale_of_decimal_conf)

@pytest.mark.parametrize('data_gen', numeric_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', numeric_gens + decimal_gens_not_max_prec, ids=idfn)
def test_subtraction(data_gen):
data_type = data_gen.data_type
assert_gpu_and_cpu_are_equal_collect(
@@ -41,7 +45,8 @@ def test_subtraction(data_gen):
f.lit(-12).cast(data_type) - f.col('b'),
f.lit(None).cast(data_type) - f.col('a'),
f.col('b') - f.lit(None).cast(data_type),
f.col('a') - f.col('b')))
f.col('a') - f.col('b')),
conf=allow_negative_scale_of_decimal_conf)

@pytest.mark.parametrize('data_gen', numeric_gens, ids=idfn)
def test_multiplication(data_gen):
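Two things about the new test parametrization: decimal_gen_neg_scale produces negative scales, which Spark rejects unless spark.sql.legacy.allowNegativeScaleOfDecimal is set, hence the added conf=allow_negative_scale_of_decimal_conf; and the maximum-precision generator is left out because an add or subtract widens the result by one digit, so inputs already at the precision cap would trigger Spark's precision-loss handling, which this commit does not claim to support. A sketch of Spark's sizing rule for + and - (the helper name is ours, for illustration):

    def add_sub_result_type(p1, s1, p2, s2):
        # Spark's DecimalPrecision rule for add/subtract:
        #   scale     = max(s1, s2)
        #   precision = max(p1 - s1, p2 - s2) + scale + 1
        scale = max(s1, s2)
        precision = max(p1 - s1, p2 - s2) + scale + 1
        return precision, scale

    # DECIMAL(5,2) + DECIMAL(7,1) -> DECIMAL(9,2)
    assert add_sub_result_type(5, 2, 7, 1) == (9, 2)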
sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -41,7 +41,6 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.execution.datasources.v2.{AlterNamespaceSetPropertiesExec, AlterTableExec, AtomicReplaceTableExec, BatchScanExec, CreateNamespaceExec, CreateTableExec, DeleteFromTableExec, DescribeNamespaceExec, DescribeTableExec, DropNamespaceExec, DropTableExec, RefreshTableExec, RenameTableExec, ReplaceTableExec, SetCatalogAndNamespaceExec, ShowCurrentNamespaceExec, ShowNamespacesExec, ShowTablePropertiesExec, ShowTablesExec}
import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.python._
@@ -706,6 +705,14 @@ object GpuOverrides {
}
}
}),
expr[PromotePrecision](
"PromotePrecision before arithmetic operations between DecimalType data",
ExprChecks.unaryProjectNotLambdaInputMatchesOutput(TypeSig.DECIMAL, TypeSig.DECIMAL),
(a, conf, p, r) => new PromotePrecisionExprMeta(a, conf, p, r)),
expr[CheckOverflow](
"CheckOverflow after arithmetic operations between DecimalType data",
ExprChecks.unaryProjectNotLambdaInputMatchesOutput(TypeSig.DECIMAL, TypeSig.DECIMAL),
(a, conf, p, r) => new CheckOverflowExprMeta(a, conf, p, r)),
expr[ToDegrees](
"Converts radians to degrees",
ExprChecks.mathUnary,
@@ -1377,19 +1384,19 @@
expr[Add](
"Addition",
ExprChecks.binaryProjectNotLambda(
TypeSig.integral + TypeSig.fp, TypeSig.numericAndInterval,
("lhs", TypeSig.integral + TypeSig.fp, TypeSig.numericAndInterval),
("rhs", TypeSig.integral + TypeSig.fp, TypeSig.numericAndInterval)),
TypeSig.numeric, TypeSig.numericAndInterval,
("lhs", TypeSig.numeric, TypeSig.numericAndInterval),
("rhs", TypeSig.numeric, TypeSig.numericAndInterval)),
(a, conf, p, r) => new BinaryExprMeta[Add](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuAdd(lhs, rhs)
}),
expr[Subtract](
"Subtraction",
ExprChecks.binaryProjectNotLambda(
TypeSig.integral + TypeSig.fp, TypeSig.numericAndInterval,
("lhs", TypeSig.integral + TypeSig.fp, TypeSig.numericAndInterval),
("rhs", TypeSig.integral + TypeSig.fp, TypeSig.numericAndInterval)),
TypeSig.numeric, TypeSig.numericAndInterval,
("lhs", TypeSig.numeric, TypeSig.numericAndInterval),
("rhs", TypeSig.numeric, TypeSig.numericAndInterval)),
(a, conf, p, r) => new BinaryExprMeta[Subtract](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuSubtract(lhs, rhs)
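Widening TypeSig.integral + TypeSig.fp to TypeSig.numeric is what admits DecimalType into Add and Subtract, and the two new expr[...] registrations above let the surrounding PromotePrecision/CheckOverflow wrappers translate to the GPU along with them. An end-to-end sketch mirroring the integration tests (assumptions, not part of this diff: the rapids-4-spark jar is on the classpath, spark.plugins is the loading mechanism, and spark.sql.legacy.allowNegativeScaleOfDecimal is the key behind allow_negative_scale_of_decimal_conf):

    from pyspark.sql import SparkSession

    # Sketch: decimal addition routed through the plugin; both .config() values
    # below are assumptions based on the tests above, not taken from this diff.
    spark = (SparkSession.builder
             .master("local[1]")
             .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
             .config("spark.sql.legacy.allowNegativeScaleOfDecimal", "true")
             .getOrCreate())
    spark.sql(
        "SELECT CAST(1.50 AS DECIMAL(10,2)) + CAST(2.25 AS DECIMAL(10,2)) AS s"
    ).show()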
