Decimal32 support #1717

Merged 37 commits on Apr 13, 2021

Commits
3b67957 Add support for Decimal32 (razajafri, Jan 28, 2021)
a9388e8 fixed unary_minus (razajafri, Feb 1, 2021)
1c31d30 unscaledLong fix (razajafri, Feb 1, 2021)
8832ffa More support for Decimal32 (razajafri, Feb 3, 2021)
7db8137 implicit for casting dec32todec64 (razajafri, Feb 5, 2021)
b23754c cleanup (razajafri, Feb 5, 2021)
2c88a75 refactored castDecimalToDecimal (razajafri, Feb 6, 2021)
15b8ff5 fixed legacy decimal read for non-nested (razajafri, Feb 9, 2021)
6060aea fixed casting and added more tests (razajafri, Feb 9, 2021)
e486bd1 added nested tests for reading legacy decimals (razajafri, Feb 12, 2021)
cc82c29 removed implicit (razajafri, Feb 12, 2021)
8389019 struct working (razajafri, Feb 19, 2021)
e573a40 Lists working (razajafri, Feb 21, 2021)
b538ee6 divide not working (razajafri, Feb 21, 2021)
a8b3b0a cleanup (razajafri, Feb 22, 2021)
26fa64c division working but problem with casting (razajafri, Feb 24, 2021)
38da4e3 moved div code to GpuModLike (razajafri, Feb 26, 2021)
8702bef code cleanup (razajafri, Feb 26, 2021)
6e8310e some more fixes (razajafri, Feb 27, 2021)
688ab2e some more fixes (razajafri, Feb 27, 2021)
2216fe6 addressed review comments (razajafri, Mar 2, 2021)
ce9c00f added more comments (razajafri, Mar 2, 2021)
7a3cc95 Merge remote-tracking branch 'origin/branch-0.5' into decimal32 (razajafri, Mar 19, 2021)
4534a6c park (razajafri, Mar 20, 2021)
77c9b75 Merge remote-tracking branch 'origin/branch-0.5' into decimal32 (razajafri, Mar 24, 2021)
caa8e85 properly cast scalar (razajafri, Mar 24, 2021)
61fd3fc Merge remote-tracking branch 'origin/branch-0.5' into decimal32 (razajafri, Mar 30, 2021)
a6d42c5 fixed gpu metric (razajafri, Mar 30, 2021)
390016d Merge remote-tracking branch 'origin/branch-0.5' into decimal32 (razajafri, Mar 30, 2021)
900c919 upmerged (razajafri, Mar 30, 2021)
6fd98c9 fixed castFloatsToDecimals to pick the right precision (razajafri, Mar 31, 2021)
e14b3a7 addressed review comments (razajafri, Mar 31, 2021)
95448c3 Fixed memory leak (razajafri, Apr 2, 2021)
4101856 Merge remote-tracking branch 'origin/branch-0.5' into decimal32 (razajafri, Apr 3, 2021)
1676c87 removed length restriction (razajafri, Apr 3, 2021)
55b32e9 Merge remote-tracking branch 'origin/branch-0.5' into decimal32 (razajafri, Apr 12, 2021)
0dfc529 fixed test failure due to upmerge (razajafri, Apr 12, 2021)
5 changes: 3 additions & 2 deletions integration_tests/src/main/python/parquet_test.py
@@ -200,8 +200,9 @@ def test_ts_read_fails_datetime_legacy(gen, spark_tmp_path, ts_write, ts_rebase,
         lambda spark : readParquetCatchException(spark, data_path),
         conf=all_confs)


-@pytest.mark.parametrize('parquet_gens', [decimal_gens], ids=idfn)
+@pytest.mark.parametrize('parquet_gens', [decimal_gens,
+                                          [ArrayGen(DecimalGen(7,2), max_length=10)],
+                                          [StructGen([['child0', DecimalGen(7, 2)]])]], ids=idfn)
 @pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
 @pytest.mark.parametrize('reader_confs', reader_opt_confs)
 @pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
@@ -390,12 +390,12 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
       table: Table,
       schema: StructType): ParquetBufferConsumer = {
     val buffer = new ParquetBufferConsumer(table.getRowCount.toInt)
-    val options = ParquetWriterOptions.builder()
-      .withDecimalPrecisions(GpuParquetFileFormat.getPrecisionList(schema):_*)
+    val builder = ParquetWriterOptions.builder()
+      .withDecimalPrecisions(GpuParquetFileFormat.getPrecisionList(schema): _*)
       .withStatisticsFrequency(StatisticsFrequency.ROWGROUP)
       .withTimestampInt96(false)
-      .build()
-    withResource(Table.writeParquetChunked(options, buffer)) { writer =>
+    schema.fields.indices.foreach(index => builder.withColumnNames(s"_col$index"))
+    withResource(Table.writeParquetChunked(builder.build(), buffer)) { writer =>
       writer.write(table)
     }
     buffer
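The essential fix in this hunk is ordering: once build() has been called, later builder calls cannot affect the resulting options, so the per-column names have to be registered before the options are built. A minimal sketch of the corrected pattern, assuming only the builder methods that appear in this diff and a Spark StructType named schema in scope:

// Sketch of the ordering fix above, not a drop-in replacement.
val builder = ParquetWriterOptions.builder()
  .withDecimalPrecisions(GpuParquetFileFormat.getPrecisionList(schema): _*)
  .withStatisticsFrequency(StatisticsFrequency.ROWGROUP)
  .withTimestampInt96(false)
// Register a synthetic name (_col0, _col1, ...) for every top-level column
// BEFORE build(); an options object built earlier would not carry the names.
schema.fields.indices.foreach(index => builder.withColumnNames(s"_col$index"))
withResource(Table.writeParquetChunked(builder.build(), buffer)) { writer =>
  writer.write(table)
}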
@@ -16,7 +16,7 @@

 package org.apache.spark.sql.rapids.shims.spark311

-import com.nvidia.spark.rapids.GpuExec
+import com.nvidia.spark.rapids.{GpuExec, GpuMetric}
 import com.nvidia.spark.rapids.shims.spark311.ParquetCachedBatchSerializer

 import org.apache.spark.rdd.RDD
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
-import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.vectorized.ColumnarBatch

 case class GpuInMemoryTableScanExec(
@@ -54,7 +53,7 @@ case class GpuInMemoryTableScanExec(
     relation.cacheBuilder.serializer.vectorTypes(attributes, conf)

   private lazy val columnarInputRDD: RDD[ColumnarBatch] = {
-    val numOutputRows = longMetric("numOutputRows")
+    val numOutputRows = gpuLongMetric(GpuMetric.NUM_OUTPUT_ROWS)
     val buffers = filteredCachedBatches()
     relation.cacheBuilder.serializer.asInstanceOf[ParquetCachedBatchSerializer]
       .gpuConvertCachedBatchToColumnarBatch(
@@ -106,7 +106,7 @@ private static String hexString(byte[] bytes) {
   public static synchronized void debug(String name, HostColumnVectorCore hostCol) {
     DType type = hostCol.getType();
     System.err.println("COLUMN " + name + " - " + type);
-    if (type.getTypeId() == DType.DTypeEnum.DECIMAL64) {
+    if (type.isDecimalType()) {
       for (int i = 0; i < hostCol.getRowCount(); i++) {
         if (hostCol.isNull(i)) {
           System.err.println(i + " NULL");
@@ -472,8 +472,7 @@ private static DType toRapidsOrNull(DataType type) {
       if (dt.precision() > DType.DECIMAL64_MAX_PRECISION) {
         return null;
       } else {
-        // Map all DecimalType to DECIMAL64, in case of underlying DType transaction.
-        return DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale());
+        return DecimalUtil.createCudfDecimal(dt.precision(), dt.scale());
       }
     }
     return null;
@@ -864,7 +863,6 @@ public static int[] toIntArray(ai.rapids.cudf.ColumnVector vec) {
    */
   GpuColumnVector(DataType type, ai.rapids.cudf.ColumnVector cudfCv) {
     super(type);
-    // TODO need some checks to be sure everything matches
     this.cudfCv = cudfCv;
   }
@@ -163,10 +163,16 @@ public final ColumnarMap getMap(int ordinal) {
   @Override
   public final Decimal getDecimal(int rowId, int precision, int scale) {
     assert precision <= DType.DECIMAL64_MAX_PRECISION : "Assert " + precision + " <= DECIMAL64_MAX_PRECISION(" + DType.DECIMAL64_MAX_PRECISION + ")";
-    assert cudfCv.getType().getTypeId() == DType.DTypeEnum.DECIMAL64: "Assert DType to be DECIMAL64";
-    assert scale == -cudfCv.getType().getScale() :
-        "Assert fetch decimal with its original scale " + scale + " expected " + (-cudfCv.getType().getScale());
-    return Decimal.createUnsafe(cudfCv.getLong(rowId), precision, scale);
+    if (precision <= Decimal.MAX_INT_DIGITS()) {
+      assert cudfCv.getType().getTypeId() == DType.DTypeEnum.DECIMAL32 : "type should be DECIMAL32";
+      return Decimal.createUnsafe(cudfCv.getInt(rowId), precision, scale);
+    } else {
+      assert cudfCv.getType().getTypeId() == DType.DTypeEnum.DECIMAL64 : "type should be DECIMAL64";
+      return Decimal.createUnsafe(cudfCv.getLong(rowId), precision, scale);
+    }
   }

   @Override
@@ -238,8 +238,9 @@ public Decimal getDecimal(int ordinal, int precision, int scale) {
     if (isNullAt(ordinal)) {
       return null;
     }
-    // TODO when DECIMAL32 is supported a special case will need to be added here
-    if (precision <= Decimal.MAX_LONG_DIGITS()) {
+    if (precision <= Decimal.MAX_INT_DIGITS()) {
+      return Decimal.createUnsafe(getInt(ordinal), precision, scale);
+    } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
       return Decimal.createUnsafe(getLong(ordinal), precision, scale);
     } else {
       throw new IllegalArgumentException("NOT IMPLEMENTED YET");
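Both getDecimal changes above rely on the same two Spark constants: Decimal.MAX_INT_DIGITS (9) and Decimal.MAX_LONG_DIGITS (18), the largest precisions whose unscaled values are guaranteed to fit in 32 and 64 bits respectively. A self-contained sketch of the tiering decision (the helper name is hypothetical, not code from this PR):

// Hypothetical helper illustrating the precision tiers used in the hunks
// above; 9 and 18 are the values of Spark's Decimal.MAX_INT_DIGITS and
// Decimal.MAX_LONG_DIGITS.
def decimalStorage(precision: Int): String =
  if (precision <= 9) "DECIMAL32: unscaled value read with getInt"
  else if (precision <= 18) "DECIMAL64: unscaled value read with getLong"
  else "unsupported: rejected before reaching these accessors"

// e.g. DecimalType(7, 2)  -> DECIMAL32
//      DecimalType(15, 4) -> DECIMAL64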
@@ -0,0 +1,43 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import ai.rapids.cudf.DType

import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

object DecimalUtil {

def createCudfDecimal(precision: Int, scale: Int): DType = {
[Review thread]
Collaborator: nit: can we have some comments here? I would like it clear that the input precision and scale should be what Spark expects, and this will convert it into whatever CUDF expects.

Collaborator Author: I will fix this in one of my PRs.
if (precision <= Decimal.MAX_INT_DIGITS) {
DType.create(DType.DTypeEnum.DECIMAL32, -scale)
} else {
DType.create(DType.DTypeEnum.DECIMAL64, -scale)
}
}

/**
 * Returns the size in bytes of a fixed-width data type.
 * WARNING: Do not use this method for variable-width data types.
 */
private[rapids] def getDataTypeSize(dt: DataType): Int = {
dt match {
case d: DecimalType if d.precision <= Decimal.MAX_INT_DIGITS => 4
case t => t.defaultSize
}
}
}
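Following up on the review thread above, here is a sketch of the comments the reviewer asked for; the doc text is illustrative and not part of this PR. The key conventions: the arguments follow Spark's DecimalType (a non-negative scale counting fractional digits), while cuDF encodes scale as a power-of-ten exponent, hence the negation.

/**
 * Creates a cuDF decimal DType from a Spark precision and scale.
 * (Illustrative doc comment, not from this PR.)
 *
 * The inputs are what Spark expects: `precision` is the total number of
 * digits and `scale` counts digits to the right of the decimal point. cuDF
 * stores scale as a power-of-ten exponent, so the sign is flipped here.
 * Precisions up to Decimal.MAX_INT_DIGITS (9) fit the unscaled value in
 * 32 bits and map to DECIMAL32; larger precisions (up to 18 digits) map
 * to DECIMAL64.
 */
def createCudfDecimal(precision: Int, scale: Int): DType = {
  if (precision <= Decimal.MAX_INT_DIGITS) {
    DType.create(DType.DTypeEnum.DECIMAL32, -scale)
  } else {
    DType.create(DType.DTypeEnum.DECIMAL64, -scale)
  }
}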