
Commit

Address comments + add test cases
johanl-db committed Dec 19, 2023
1 parent dc8b489 commit dd1975c
Showing 3 changed files with 62 additions and 39 deletions.
ParquetVectorUpdaterFactory.java

@@ -22,6 +22,7 @@
 import org.apache.parquet.io.api.Binary;
 import org.apache.parquet.schema.LogicalTypeAnnotation;
 import org.apache.parquet.schema.LogicalTypeAnnotation.IntLogicalTypeAnnotation;
+import org.apache.parquet.schema.LogicalTypeAnnotation.DateLogicalTypeAnnotation;
 import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
 import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation;
 import org.apache.parquet.schema.PrimitiveType;
@@ -98,7 +99,7 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
           boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
           return new IntegerWithRebaseUpdater(failIfRebase);
         }
-      } else if (sparkType == DataTypes.TimestampNTZType) {
+      } else if (sparkType == DataTypes.TimestampNTZType && isDateTypeMatched(descriptor)) {
         if ("CORRECTED".equals(datetimeRebaseMode)) {
           return new DateToTimestampNTZUpdater();
         } else {
@@ -376,8 +377,7 @@ public void readValues(
         WritableColumnVector values,
         VectorizedValuesReader valuesReader) {
       for (int i = 0; i < total; ++i) {
-        long days = DateTimeUtils.daysToMicros(valuesReader.readInteger(), ZoneOffset.UTC);
-        values.putLong(offset + i, days);
+        readValue(offset + i, values, valuesReader);
       }
     }

@@ -420,8 +420,7 @@ public void readValues(
         WritableColumnVector values,
         VectorizedValuesReader valuesReader) {
       for (int i = 0; i < total; ++i) {
-        int rebasedDays = rebaseDays(valuesReader.readInteger(), failIfRebase);
-        values.putLong(offset + i, DateTimeUtils.daysToMicros(rebasedDays, ZoneOffset.UTC));
+        readValue(offset + i, values, valuesReader);
       }
     }

@@ -1436,6 +1435,11 @@ private static boolean isTimestamp(DataType dt) {
     return dt == DataTypes.TimestampType || dt == DataTypes.TimestampNTZType;
   }

+  boolean isDateTypeMatched(ColumnDescriptor descriptor) {
+    LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+    return typeAnnotation instanceof DateLogicalTypeAnnotation;
+  }
+
   private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataType dt) {
     DecimalType d = (DecimalType) dt;
     LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
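
Note on the new guard: isDateTypeMatched keys the TimestampNTZ branch in getUpdater off the column's Parquet logical type annotation rather than its physical INT32 type, so only columns actually annotated as DATE take the date-widening path. A minimal standalone sketch of the same check against a hand-built Parquet type (Scala; the column names are illustrative, not from this patch):

import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Types}
import org.apache.parquet.schema.LogicalTypeAnnotation.DateLogicalTypeAnnotation
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName

object DateAnnotationCheck extends App {
  // Same shape as the new guard: true only when the column carries a DATE annotation.
  def isDateTypeMatched(col: PrimitiveType): Boolean =
    col.getLogicalTypeAnnotation.isInstanceOf[DateLogicalTypeAnnotation]

  val dateCol = Types.optional(PrimitiveTypeName.INT32)
    .as(LogicalTypeAnnotation.dateType()).named("d")
  val plainIntCol = Types.optional(PrimitiveTypeName.INT32).named("i")

  assert(isDateTypeMatched(dateCol))      // eligible for DATE -> TIMESTAMP_NTZ widening
  assert(!isDateTypeMatched(plainIntCol)) // a bare INT32 is not a day count
}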
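The two readValues hunks above fold the inlined conversion into readValue, so the read and skip paths share one conversion per element (the removed code also stored the converted microseconds in a variable named days). What that conversion computes: a Parquet DATE is a day count since 1970-01-01, and widening it to TIMESTAMP_NTZ scales it to microseconds at UTC midnight. A quick standalone check of the arithmetic (plain java.time, not Spark's DateTimeUtils):

import java.time.{LocalDate, ZoneOffset}

object DaysToMicrosCheck extends App {
  // 2020-01-02 as stored in a Parquet DATE column: days since the epoch.
  val days: Long = LocalDate.of(2020, 1, 2).toEpochDay // 18263
  // Widening to TIMESTAMP_NTZ, as DateTimeUtils.daysToMicros(days, ZoneOffset.UTC) does.
  val micros: Long = days * 24L * 60 * 60 * 1000 * 1000
  val expected = LocalDate.ofEpochDay(days).atStartOfDay
    .toInstant(ZoneOffset.UTC).toEpochMilli * 1000
  assert(micros == expected) // 1577923200000000
}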
VectorizedColumnReader.java

@@ -145,42 +145,37 @@ public VectorizedColumnReader(

   private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName,
       DataType sparkType) {
-    // Don't use lazy dictionary decoding if the column needs extra processing: upcasting or date /
-    // decimal scale rebasing.
-    return switch (typeName) {
-      case INT32 -> {
+    boolean isSupported = false;
+    // Don't use lazy dictionary decoding if the column needs extra processing: upcasting or date
+    // rebasing.
+    switch (typeName) {
+      case INT32: {
         boolean needsUpcast = sparkType == LongType || sparkType == TimestampNTZType ||
           !DecimalType.is32BitDecimalType(sparkType);
         boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation &&
-          !"CORRECTED".equals(datetimeRebaseMode);
-        yield !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
+          !"CORRECTED".equals(datetimeRebaseMode);
+        isSupported = !needsUpcast && !needsRebase;
+        break;
       }
-      case INT64 -> {
+      case INT64: {
         boolean needsUpcast = !DecimalType.is64BitDecimalType(sparkType) ||
           updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
         boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) &&
-          !"CORRECTED".equals(datetimeRebaseMode);
-        yield !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
+          !"CORRECTED".equals(datetimeRebaseMode);
+        isSupported = !needsUpcast && !needsRebase;
+        break;
       }
-      case FLOAT -> sparkType == FloatType;
-      case DOUBLE, BINARY -> !needsDecimalScaleRebase(sparkType);
-      default -> false;
-    };
+      case FLOAT:
+        isSupported = sparkType == FloatType;
+        break;
+      case DOUBLE:
+      case BINARY:
+        isSupported = true;
+        break;
+    }
+    return isSupported;
   }

-  /**
-   * Returns whether the Parquet type of this column and the given spark type are two decimal types
-   * with different scales.
-   */
-  private boolean needsDecimalScaleRebase(DataType sparkType) {
-    LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
-    if (!(typeAnnotation instanceof DecimalLogicalTypeAnnotation)) return false;
-    if (!(sparkType instanceof DecimalType)) return false;
-    DecimalLogicalTypeAnnotation parquetDecimal = (DecimalLogicalTypeAnnotation) typeAnnotation;
-    DecimalType sparkDecimal = (DecimalType) sparkType;
-    return parquetDecimal.getScale() != sparkDecimal.scale();
-  }
-
   /**
    * Reads `total` rows from this columnReader into column.
    */
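Context for the isLazyDecodingSupported change: with lazy dictionary decoding, the column vector keeps raw dictionary ids and materializes values straight from the dictionary page later, which is only safe when decoded values need no per-element conversion. Hence any upcast or rebase forces eager decoding. A toy illustration of the two paths (illustrative names only, not Spark's vector classes):

object LazyDictionarySketch extends App {
  // An INT32 dictionary page, and the column encoded as dictionary ids.
  val dictionary: Array[Int] = Array(7, 42, Int.MaxValue)
  val ids: Array[Int] = Array(2, 0, 1, 1)

  // Eager path: convert while reading, so an upcast (INT32 read as LongType)
  // is applied per element as the ids are resolved.
  val eager: Array[Long] = ids.map(id => dictionary(id).toLong)

  // Lazy path: keep the ids and defer to the dictionary. This only works when
  // the stored values can be served as-is, which is why upcasts and rebases
  // make isLazyDecodingSupported return false.
  def lazyValue(row: Int): Int = dictionary(ids(row))

  assert(eager(0) == Int.MaxValue.toLong && lazyValue(0) == Int.MaxValue)
}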
ParquetTypeWideningSuite.scala

@@ -27,7 +27,8 @@ import org.apache.spark.sql.{DataFrame, QueryTest, Row}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException
 import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
+import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._

@@ -50,8 +51,10 @@ class ParquetTypeWideningSuite
       toType: DataType,
       expectError: => Boolean): Unit = {
     val timestampRebaseModes = toType match {
-      case _: TimestampNTZType | _: DateType => Seq("CORRECTED", "LEGACY")
-      case _ => Seq("CORRECTED")
+      case _: TimestampNTZType | _: DateType =>
+        Seq(LegacyBehaviorPolicy.CORRECTED, LegacyBehaviorPolicy.LEGACY)
+      case _ =>
+        Seq(LegacyBehaviorPolicy.CORRECTED)
     }
     for {
       dictionaryEnabled <- Seq(true, false)
@@ -72,8 +75,10 @@
         assert(
           exception.getCause.getCause
             .isInstanceOf[SchemaColumnConvertNotSupportedException] ||
-          exception.getCause.getCause
-            .isInstanceOf[org.apache.parquet.io.ParquetDecodingException])
+            exception.getCause.getCause
+              .isInstanceOf[org.apache.parquet.io.ParquetDecodingException] ||
+            exception.getCause.getMessage.contains(
+              "Unable to create Parquet converter for data type"))
       } else {
         checkAnswer(readParquetFiles(dir, toType), expected.select($"a".cast(toType)))
       }
@@ -102,12 +107,13 @@
       values: Seq[String],
       dataType: DataType,
       dictionaryEnabled: Boolean,
-      timestampRebaseMode: String = "CORRECTED"): DataFrame = {
+      timestampRebaseMode: LegacyBehaviorPolicy.Value = LegacyBehaviorPolicy.CORRECTED)
+    : DataFrame = {
     val repeatedValues = List.fill(if (dictionaryEnabled) 10 else 1)(values).flatten
     val df = repeatedValues.toDF("a").select(col("a").cast(dataType))
     withSQLConf(
       ParquetOutputFormat.ENABLE_DICTIONARY -> dictionaryEnabled.toString,
-      SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key -> timestampRebaseMode) {
+      SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key -> timestampRebaseMode.toString) {
       df.write.mode("overwrite").parquet(dir.getAbsolutePath)
     }

@@ -160,13 +166,31 @@
       (Seq("1", "2", Int.MinValue.toString), LongType, IntegerType),
       (Seq("1.23", "10.34"), DoubleType, FloatType),
       (Seq("1.23", "10.34"), FloatType, LongType),
-      (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampNTZType, DateType)
+      (Seq("1.23", "10.34"), LongType, DateType),
+      (Seq("1.23", "10.34"), IntegerType, TimestampType),
+      (Seq("1.23", "10.34"), IntegerType, TimestampNTZType),
+      (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampType)
     )
   }
   test(s"unsupported parquet conversion $fromType -> $toType") {
     checkAllParquetReaders(values, fromType, toType, expectError = true)
   }

+  for {
+    (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
+      (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampType, DateType),
+      (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampNTZType, DateType))
+    outputTimestampType <- ParquetOutputTimestampType.values
+  }
+  test(s"unsupported parquet timestamp conversion $fromType ($outputTimestampType) -> $toType") {
+    withSQLConf(
+      SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> outputTimestampType.toString,
+      SQLConf.PARQUET_INT96_REBASE_MODE_IN_WRITE.key -> LegacyBehaviorPolicy.CORRECTED.toString
+    ) {
+      checkAllParquetReaders(values, fromType, toType, expectError = true)
+    }
+  }
+
   for {
     (fromPrecision, toPrecision) <-
       // Test widening and narrowing precision between the same and different decimal physical
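
A note on the shared test value "1312-02-27": it predates the 1582 Gregorian cutover, the range where LEGACY and CORRECTED rebase modes produce different physical values, which is why the suite runs date and timestamp cases under both writer modes. A minimal sketch of exercising the writer config the same way (assumes a local SparkSession; not the suite's actual helper):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object RebaseModeSketch extends App {
  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  import spark.implicits._

  for (mode <- Seq("CORRECTED", "LEGACY")) {
    spark.conf.set(SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key, mode)
    // Pre-Gregorian dates like 1312-02-27 are written differently per mode.
    Seq("2020-01-01", "1312-02-27").toDF("a")
      .selectExpr("CAST(a AS DATE) AS a")
      .write.mode("overwrite").parquet(s"/tmp/rebase_$mode")
  }
  spark.stop()
}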
