Promote to decimal with larger scale and precision
johanl-db committed Dec 27, 2023
1 parent b106f80 commit 2966027
Showing 4 changed files with 262 additions and 13 deletions.
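As a rough sketch of what this change enables (hypothetical path and column name, mirroring the decimal values used in the tests below), the vectorized Parquet reader can now promote a stored decimal to a type with both larger precision and larger scale:

// Write a Parquet column stored as DECIMAL(5, 2).
spark.sql("SELECT CAST(1.23 AS DECIMAL(5, 2)) AS a").write.parquet("/tmp/decimal_promotion")
// After this change the vectorized reader accepts a read schema with a larger scale,
// provided the precision grows by at least as much, e.g. DECIMAL(7, 4):
spark.read.schema("a DECIMAL(7, 4)").parquet("/tmp/decimal_promotion").show()  // value read back as 1.2300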
@@ -34,7 +34,9 @@
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.*;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.util.Arrays;
@@ -108,6 +110,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
}
} else if (sparkType instanceof YearMonthIntervalType) {
return new IntegerUpdater();
} else if (canReadAsDecimal(descriptor, sparkType)) {
return new IntegerToDecimalUpdater(descriptor, (DecimalType) sparkType);
}
}
case INT64 -> {
@@ -153,6 +157,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
return new LongAsMicrosUpdater();
} else if (sparkType instanceof DayTimeIntervalType) {
return new LongUpdater();
} else if (canReadAsDecimal(descriptor, sparkType)) {
return new LongToDecimalUpdater(descriptor, (DecimalType) sparkType);
}
}
case FLOAT -> {
@@ -194,6 +200,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
if (sparkType == DataTypes.StringType || sparkType == DataTypes.BinaryType ||
canReadAsBinaryDecimal(descriptor, sparkType)) {
return new BinaryUpdater();
} else if (canReadAsDecimal(descriptor, sparkType)) {
return new BinaryToDecimalUpdater(descriptor, (DecimalType) sparkType);
}
}
case FIXED_LEN_BYTE_ARRAY -> {
@@ -206,6 +214,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
return new FixedLenByteArrayUpdater(arrayLen);
} else if (sparkType == DataTypes.BinaryType) {
return new FixedLenByteArrayUpdater(arrayLen);
} else if (canReadAsDecimal(descriptor, sparkType)) {
return new FixedLenByteArrayToDecimalUpdater(descriptor, (DecimalType) sparkType);
}
}
default -> {}
@@ -1358,6 +1368,180 @@ public void decodeSingleDictionaryId(
}
}

private abstract static class DecimalUpdater implements ParquetVectorUpdater {

private final DecimalType sparkType;

DecimalUpdater(DecimalType sparkType) {
this.sparkType = sparkType;
}

@Override
public void readValues(
int total,
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
for (int i = 0; i < total; i++) {
readValue(offset + i, values, valuesReader);
}
}

protected void writeDecimal(int offset, WritableColumnVector values, BigDecimal decimal) {
BigDecimal scaledDecimal = decimal.setScale(sparkType.scale(), RoundingMode.UNNECESSARY);
if (DecimalType.is32BitDecimalType(sparkType)) {
values.putInt(offset, scaledDecimal.unscaledValue().intValue());
} else if (DecimalType.is64BitDecimalType(sparkType)) {
values.putLong(offset, scaledDecimal.unscaledValue().longValue());
} else {
values.putByteArray(offset, scaledDecimal.unscaledValue().toByteArray());
}
}
}
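For illustration, the rescaling done by writeDecimal above amounts to the following standalone sketch (Scala, using java.math.BigDecimal directly; not code from the commit):

import java.math.{BigDecimal => JBigDecimal}
import java.math.RoundingMode

// Parquet stores the DECIMAL(5, 2) value 123.45 as the unscaled integer 12345.
val fromParquet = JBigDecimal.valueOf(12345L, 2)                  // 123.45
// Promoting to DECIMAL(7, 4) rescales it to 123.4500; RoundingMode.UNNECESSARY makes the
// operation throw instead of silently dropping digits if the target scale were too small.
val promoted = fromParquet.setScale(4, RoundingMode.UNNECESSARY)
assert(promoted.unscaledValue.longValueExact == 1234500L)         // fits the long-backed Decimal(7, 4)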

private static class IntegerToDecimalUpdater extends DecimalUpdater {
private final int parquetScale;

IntegerToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
super(sparkType);
LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
}

@Override
public void skipValues(int total, VectorizedValuesReader valuesReader) {
valuesReader.skipIntegers(total);
}

@Override
public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
BigDecimal decimal = BigDecimal.valueOf(valuesReader.readInteger(), parquetScale);
writeDecimal(offset, values, decimal);
}

@Override
public void decodeSingleDictionaryId(
int offset,
WritableColumnVector values,
WritableColumnVector dictionaryIds,
Dictionary dictionary) {
BigDecimal decimal = BigDecimal.valueOf(dictionary.decodeToInt(dictionaryIds.getDictId(offset)), parquetScale);
writeDecimal(offset, values, decimal);
}
}

private static class LongToDecimalUpdater extends DecimalUpdater {
private final int parquetScale;

LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
super(sparkType);
LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
}

@Override
public void skipValues(int total, VectorizedValuesReader valuesReader) {
valuesReader.skipLongs(total);
}

@Override
public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
BigDecimal decimal = BigDecimal.valueOf(valuesReader.readLong(), parquetScale);
writeDecimal(offset, values, decimal);
}

@Override
public void decodeSingleDictionaryId(
int offset,
WritableColumnVector values,
WritableColumnVector dictionaryIds,
Dictionary dictionary) {
BigDecimal decimal = BigDecimal.valueOf(dictionary.decodeToLong(dictionaryIds.getDictId(offset)), parquetScale);
writeDecimal(offset, values, decimal);
}
}

private static class BinaryToDecimalUpdater extends DecimalUpdater {
private final int parquetScale;

BinaryToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
super(sparkType);
LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
}

@Override
public void skipValues(int total, VectorizedValuesReader valuesReader) {
valuesReader.skipBinary(total);
}

@Override
public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
valuesReader.readBinary(1, values, offset);
BigInteger value = new BigInteger(values.getBinary(offset));
BigDecimal decimal = new BigDecimal(value, parquetScale);
writeDecimal(offset, values, decimal);
}

@Override
public void decodeSingleDictionaryId(
int offset,
WritableColumnVector values,
WritableColumnVector dictionaryIds,
Dictionary dictionary) {
BigInteger value = new BigInteger(dictionary.decodeToBinary(dictionaryIds.getDictId(offset)).getBytes());
BigDecimal decimal = new BigDecimal(value, parquetScale);
writeDecimal(offset, values, decimal);
}
}

private static class FixedLenByteArrayToDecimalUpdater extends DecimalUpdater {
private final int parquetScale;
private final int arrayLen;

FixedLenByteArrayToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
super(sparkType);
LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
this.arrayLen = descriptor.getPrimitiveType().getTypeLength();
}

@Override
public void skipValues(int total, VectorizedValuesReader valuesReader) {
valuesReader.skipFixedLenByteArray(total, arrayLen);
}

@Override
public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
BigInteger value = new BigInteger(valuesReader.readBinary(arrayLen).getBytes());
BigDecimal decimal = new BigDecimal(value, this.parquetScale);
writeDecimal(offset, values, decimal);
}

@Override
public void decodeSingleDictionaryId(
int offset,
WritableColumnVector values,
WritableColumnVector dictionaryIds,
Dictionary dictionary) {
BigInteger value = new BigInteger(dictionary.decodeToBinary(dictionaryIds.getDictId(offset)).getBytes());
BigDecimal decimal = new BigDecimal(value, this.parquetScale);
writeDecimal(offset, values, decimal);
}
}

private static int rebaseDays(int julianDays, final boolean failIfRebase) {
if (failIfRebase) {
if (julianDays < RebaseDateTime.lastSwitchJulianDay()) {
@@ -1418,16 +1602,21 @@ private SchemaColumnConvertNotSupportedException constructConvertNotSupportedExc

private static boolean canReadAsIntDecimal(ColumnDescriptor descriptor, DataType dt) {
if (!DecimalType.is32BitDecimalType(dt)) return false;
return isDecimalTypeMatched(descriptor, dt);
return isDecimalTypeMatched(descriptor, dt) && isSameDecimalScale(descriptor, dt);
}

private static boolean canReadAsLongDecimal(ColumnDescriptor descriptor, DataType dt) {
if (!DecimalType.is64BitDecimalType(dt)) return false;
return isDecimalTypeMatched(descriptor, dt);
return isDecimalTypeMatched(descriptor, dt) && isSameDecimalScale(descriptor, dt);
}

private static boolean canReadAsBinaryDecimal(ColumnDescriptor descriptor, DataType dt) {
if (!DecimalType.isByteArrayDecimalType(dt)) return false;
return isDecimalTypeMatched(descriptor, dt) && isSameDecimalScale(descriptor, dt);
}

private static boolean canReadAsDecimal(ColumnDescriptor descriptor, DataType dt) {
if (!(dt instanceof DecimalType)) return false;
return isDecimalTypeMatched(descriptor, dt);
}

@@ -1447,11 +1636,24 @@ private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataTyp
DecimalType d = (DecimalType) dt;
LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
if (typeAnnotation instanceof DecimalLogicalTypeAnnotation decimalType) {
// It's OK if the required decimal precision is larger than or equal to the physical decimal
// precision in the Parquet metadata, as long as the decimal scale is the same.
return decimalType.getPrecision() <= d.precision() && decimalType.getScale() == d.scale();
// If the required scale is larger than or equal to the physical decimal scale in the Parquet
// metadata, we can upscale the value as long as the precision also increases by as much so
// that there is no loss of precision.
return decimalType.getPrecision() <= d.precision() &&
(decimalType.getPrecision() - decimalType.getScale()) <= (d.precision() - d.scale());
}
return false;
}

private static boolean isSameDecimalScale(ColumnDescriptor descriptor, DataType dt) {
DecimalType d = (DecimalType) dt;
LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
DecimalLogicalTypeAnnotation decimalType = (DecimalLogicalTypeAnnotation) typeAnnotation;
return decimalType.getScale() == d.scale();
}
return false;
}

}
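The matching rule enforced by isDecimalTypeMatched can be summarized as: precision may not shrink, and the number of integer digits (precision minus scale) may not shrink either. A small Scala sketch with a hypothetical helper (not part of the commit):

def canPromote(p1: Int, s1: Int, p2: Int, s2: Int): Boolean =
  // Parquet DECIMAL(p1, s1) may be read as Spark DECIMAL(p2, s2) when neither the
  // precision nor the number of integer digits decreases.
  p1 <= p2 && (p1 - s1) <= (p2 - s2)

assert(canPromote(5, 2, 7, 4))    // 123.45 widens to 123.4500
assert(!canPromote(5, 2, 6, 4))   // 999.99 could not fit into DECIMAL(6, 4)
assert(!canPromote(5, 2, 4, 2))   // precision would shrink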

@@ -152,32 +152,50 @@ private boolean isLazyDecodingSupported(
switch (typeName) {
case INT32: {
boolean isDate = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation;
boolean needsUpcast = sparkType == LongType || (isDate && sparkType == TimestampNTZType) ||
!DecimalType.is32BitDecimalType(sparkType);
boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
boolean needsUpcast = sparkType == LongType || sparkType == DoubleType ||
(isDate && sparkType == TimestampNTZType) ||
(isDecimal && !DecimalType.is32BitDecimalType(sparkType));
boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation &&
!"CORRECTED".equals(datetimeRebaseMode);
isSupported = !needsUpcast && !needsRebase;
isSupported = !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
break;
}
case INT64: {
boolean needsUpcast = !DecimalType.is64BitDecimalType(sparkType) ||
boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
boolean needsUpcast = (isDecimal && !DecimalType.is64BitDecimalType(sparkType)) ||
updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) &&
!"CORRECTED".equals(datetimeRebaseMode);
isSupported = !needsUpcast && !needsRebase;
isSupported = !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
break;
}
case FLOAT:
isSupported = sparkType == FloatType;
break;
case DOUBLE:
case BINARY:
isSupported = true;
break;
case BINARY:
isSupported = !needsDecimalScaleRebase(sparkType);
break;
}
return isSupported;
}

/**
* Returns whether the Parquet type of this column and the given spark type are two decimal types
* with different scale.
*/
private boolean needsDecimalScaleRebase(DataType sparkType) {
LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
if (!(typeAnnotation instanceof DecimalLogicalTypeAnnotation)) return false;
if (!(sparkType instanceof DecimalType)) return false;
DecimalLogicalTypeAnnotation parquetDecimal = (DecimalLogicalTypeAnnotation) typeAnnotation;
DecimalType sparkDecimal = (DecimalType) sparkType;
return parquetDecimal.getScale() != sparkDecimal.scale();
}

/**
* Reads `total` rows from this columnReader into column.
*/
@@ -1049,7 +1049,9 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
}

withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
Seq("a DECIMAL(3, 2)", "b DECIMAL(18, 1)", "c DECIMAL(37, 1)").foreach { schema =>
val schema1 = "a DECIMAL(3, 2), b DECIMAL(18, 3), c DECIMAL(37, 3)"
checkAnswer(readParquet(schema1, path), df)
Seq("a DECIMAL(3, 0)", "b DECIMAL(18, 1)", "c DECIMAL(37, 1)").foreach { schema =>
val e = intercept[SparkException] {
readParquet(schema, path).collect()
}.getCause.getCause
@@ -209,10 +209,37 @@ class ParquetTypeWideningSuite
toType = DecimalType(toPrecision, 2),
expectError = fromPrecision > toPrecision &&
// parquet-mr allows reading decimals into a smaller precision decimal type without
// checking for overflows. See test below.
// checking for overflows. See test below checking for the overflow case in parquet-mr.
spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
}

for {
((fromPrecision, fromScale), (toPrecision, toScale)) <-
// Test widening and narrowing precision and scale by the same amount between the same and
// different decimal physical parquet types:
// - INT32: precisions 5, 7
// - INT64: precisions 10, 12
// - FIXED_LEN_BYTE_ARRAY: precisions 20, 22
Seq((5, 2) -> (7, 4), (5, 2) -> (10, 7), (5, 2) -> (20, 17), (10, 2) -> (12, 4),
(10, 2) -> (20, 12), (20, 2) -> (22, 4)) ++
Seq((7, 4) -> (5, 2), (10, 7) -> (5, 2), (20, 17) -> (5, 2), (12, 4) -> (10, 2),
(20, 17) -> (10, 2), (22, 4) -> (20, 2))
}
test(s"parquet decimal precision and scale change Decimal($fromPrecision, $fromScale) -> " +
s"Decimal($toPrecision, $toScale)"
) {
checkAllParquetReaders(
values = Seq("1.23", "10.34"),
fromType = DecimalType(fromPrecision, fromScale),
toType = DecimalType(toPrecision, toScale),
expectError =
// parquet-mr allows reading decimals into a smaller precision/scale decimal type without
// checking for overflows. See test below checking for the overflow case in parquet-mr.
fromPrecision > toPrecision &&
spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean
)
}

test("parquet decimal type change Decimal(5, 2) -> Decimal(3, 2) overflows with parquet-mr") {
withTempDir { dir =>
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {