diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
index 1ed6c4329ebd0..0ece8fefac09a 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
@@ -81,6 +81,10 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
          // For unsigned int32, it stores as plain signed int32 in Parquet when dictionary
          // fallbacks. We read them as long values.
          return new UnsignedIntegerUpdater();
+        } else if (sparkType == DataTypes.LongType || canReadAsLongDecimal(descriptor, sparkType)) {
+          return new IntegerToLongUpdater();
+        } else if (canReadAsBinaryDecimal(descriptor, sparkType)) {
+          return new IntegerToBinaryUpdater();
        } else if (sparkType == DataTypes.ByteType) {
          return new ByteUpdater();
        } else if (sparkType == DataTypes.ShortType) {
@@ -92,6 +96,13 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
            boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
            return new IntegerWithRebaseUpdater(failIfRebase);
          }
+        } else if (sparkType == DataTypes.TimestampNTZType) {
+          if ("CORRECTED".equals(datetimeRebaseMode)) {
+            return new DateToTimestampNTZUpdater();
+          } else {
+            boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
+            return new DateToTimestampNTZWithRebaseUpdater(failIfRebase);
+          }
        } else if (sparkType instanceof YearMonthIntervalType) {
          return new IntegerUpdater();
        }
@@ -104,6 +115,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
          } else {
            return new LongUpdater();
          }
+        } else if (canReadAsBinaryDecimal(descriptor, sparkType)) {
+          return new LongToBinaryUpdater();
        } else if (isLongDecimal(sparkType) && isUnsignedIntTypeMatched(64)) {
          // In `ParquetToSparkSchemaConverter`, we map parquet UINT64 to our Decimal(20, 0).
          // For unsigned int64, it stores as plain signed int64 in Parquet when dictionary
@@ -134,6 +147,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
      case FLOAT -> {
        if (sparkType == DataTypes.FloatType) {
          return new FloatUpdater();
+        } else if (sparkType == DataTypes.DoubleType) {
+          return new FloatToDoubleUpdater();
        }
      }
      case DOUBLE -> {
@@ -281,6 +296,121 @@ public void decodeSingleDictionaryId(
     }
   }

+  static class IntegerToLongUpdater implements ParquetVectorUpdater {
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; ++i) {
+        values.putLong(offset + i, valuesReader.readInteger());
+      }
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipIntegers(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      values.putLong(offset, valuesReader.readInteger());
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      values.putLong(offset, dictionary.decodeToInt(dictionaryIds.getDictId(offset)));
+    }
+  }
+
+  static class DateToTimestampNTZUpdater implements ParquetVectorUpdater {
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; ++i) {
+        values.putLong(offset + i, DateTimeUtils.daysToMicros(valuesReader.readInteger(), ZoneOffset.UTC));
+      }
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipIntegers(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      values.putLong(offset, DateTimeUtils.daysToMicros(valuesReader.readInteger(), ZoneOffset.UTC));
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      int days = dictionary.decodeToInt(dictionaryIds.getDictId(offset));
+      values.putLong(offset, DateTimeUtils.daysToMicros(days, ZoneOffset.UTC));
+    }
+  }
+
+  private static class DateToTimestampNTZWithRebaseUpdater implements ParquetVectorUpdater {
+    private final boolean failIfRebase;
+
+    DateToTimestampNTZWithRebaseUpdater(boolean failIfRebase) {
+      this.failIfRebase = failIfRebase;
+    }
+
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; ++i) {
+        int rebasedDays = rebaseDays(valuesReader.readInteger(), failIfRebase);
+        values.putLong(offset + i, DateTimeUtils.daysToMicros(rebasedDays, ZoneOffset.UTC));
+      }
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipIntegers(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      int rebasedDays = rebaseDays(valuesReader.readInteger(), failIfRebase);
+      values.putLong(offset, DateTimeUtils.daysToMicros(rebasedDays, ZoneOffset.UTC));
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      int rebasedDays =
+        rebaseDays(dictionary.decodeToInt(dictionaryIds.getDictId(offset)), failIfRebase);
+      values.putLong(offset, DateTimeUtils.daysToMicros(rebasedDays, ZoneOffset.UTC));
+    }
+  }
+
   private static class UnsignedIntegerUpdater implements ParquetVectorUpdater {
     @Override
     public void readValues(
@@ -684,6 +814,41 @@ public void decodeSingleDictionaryId(
     }
   }

+  static class FloatToDoubleUpdater implements ParquetVectorUpdater {
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; ++i) {
+        values.putDouble(offset + i, valuesReader.readFloat());
+      }
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipFloats(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      values.putDouble(offset, valuesReader.readFloat());
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      values.putDouble(offset, dictionary.decodeToFloat(dictionaryIds.getDictId(offset)));
+    }
+  }
+
   private static class DoubleUpdater implements ParquetVectorUpdater {
     @Override
     public void readValues(
@@ -751,6 +916,82 @@ public void decodeSingleDictionaryId(
     }
   }

+  private static class IntegerToBinaryUpdater implements ParquetVectorUpdater {
+
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; i++) {
+        readValue(offset + i, values, valuesReader);
+      }
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipIntegers(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      BigInteger value = BigInteger.valueOf(valuesReader.readInteger());
+      values.putByteArray(offset, value.toByteArray());
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      BigInteger value = BigInteger.valueOf(dictionary.decodeToInt(dictionaryIds.getDictId(offset)));
+      values.putByteArray(offset, value.toByteArray());
+    }
+  }
+
+  private static class LongToBinaryUpdater implements ParquetVectorUpdater {
+
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; i++) {
+        readValue(offset + i, values, valuesReader);
+      }
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipLongs(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      BigInteger value = BigInteger.valueOf(valuesReader.readLong());
+      values.putByteArray(offset, value.toByteArray());
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      BigInteger value = BigInteger.valueOf(dictionary.decodeToLong(dictionaryIds.getDictId(offset)));
+      values.putByteArray(offset, value.toByteArray());
+    }
+  }
+
   private static class BinaryToSQLTimestampUpdater implements ParquetVectorUpdater {
     @Override
     public void readValues(
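[Reviewer note, not part of the patch: the updaters added above are what make reads with a wider Spark type work in the vectorized reader; they previously failed with SchemaColumnConvertNotSupportedException. A minimal sketch, assuming an active `spark` session and an illustrative scratch path under /tmp:]

    // Write INT32 and FLOAT columns, then read them back with wider Spark types.
    import org.apache.spark.sql.functions.col

    spark.range(5)
      .select(col("id").cast("int").as("a"), col("id").cast("float").as("b"))
      .write.mode("overwrite").parquet("/tmp/parquet-widening-demo")

    // INT32 read as LONG goes through IntegerToLongUpdater; FLOAT read as DOUBLE goes
    // through FloatToDoubleUpdater.
    spark.read.schema("a LONG, b DOUBLE").parquet("/tmp/parquet-widening-demo").show()
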
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index 04fbe716ad92f..33b9412c37663 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -38,10 +38,13 @@
 import org.apache.parquet.schema.PrimitiveType;

 import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
+import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;

 import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BOOLEAN;
 import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;
+import static org.apache.spark.sql.types.DataTypes.*;

 /**
  * Decoder to return values from a single column.
@@ -140,23 +143,42 @@ public VectorizedColumnReader(
     this.writerVersion = writerVersion;
   }

-  private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName) {
+  private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName,
+      DataType sparkType) {
+    // Don't use lazy dictionary decoding if the column needs extra processing: upcasting,
+    // date rebasing or decimal scale rebasing.
     return switch (typeName) {
-      case INT32 ->
-        !(logicalTypeAnnotation instanceof DateLogicalTypeAnnotation) ||
-        "CORRECTED".equals(datetimeRebaseMode);
+      case INT32 -> {
+        boolean needsUpcast = sparkType == LongType || sparkType == TimestampNTZType ||
+          !DecimalType.is32BitDecimalType(sparkType);
+        boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation && !"CORRECTED".equals(datetimeRebaseMode);
+        yield !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
+      }
       case INT64 -> {
-        if (updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS)) {
-          yield "CORRECTED".equals(datetimeRebaseMode);
-        } else {
-          yield !updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
-        }
+        boolean needsUpcast = !DecimalType.is64BitDecimalType(sparkType) ||
+          updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
+        boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) && !"CORRECTED".equals(datetimeRebaseMode);
+        yield !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
       }
-      case FLOAT, DOUBLE, BINARY -> true;
+      case FLOAT -> sparkType == FloatType;
+      case DOUBLE, BINARY -> !needsDecimalScaleRebase(sparkType);
       default -> false;
     };
   }

+  /**
+   * Returns whether the Parquet type of this column and the given Spark type are two decimal types
+   * with different scales.
+   */
+  private boolean needsDecimalScaleRebase(DataType sparkType) {
+    LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+    if (!(typeAnnotation instanceof DecimalLogicalTypeAnnotation)) return false;
+    if (!(sparkType instanceof DecimalType)) return false;
+    DecimalLogicalTypeAnnotation parquetDecimal = (DecimalLogicalTypeAnnotation) typeAnnotation;
+    DecimalType sparkDecimal = (DecimalType) sparkType;
+    return parquetDecimal.getScale() != sparkDecimal.scale();
+  }
+
   /**
    * Reads `total` rows from this columnReader into column.
    */
@@ -205,7 +227,7 @@ void readBatch(
      // TIMESTAMP_MILLIS encoded as INT64 can't be lazily decoded as we need to post process
      // the values to add microseconds precision.
      if (column.hasDictionary() || (startRowId == pageFirstRowIndex &&
-          isLazyDecodingSupported(typeName))) {
+          isLazyDecodingSupported(typeName, column.dataType()))) {
        // Column vector supports lazy decoding of dictionary values so just set the dictionary.
        // We can't do this if startRowId is not the first row index in the page AND the column
        // doesn't have a dictionary (i.e. some non-dictionary encoded values have already been
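[Reviewer note, not part of the patch: isLazyDecodingSupported now receives the requested Spark type so it can report false whenever an upcast, a date rebase, or a decimal scale change is involved. A hedged sketch of one read that the rework has to account for, a decimal widened across physical types (INT32-backed DECIMAL(5, 2) read back as an INT64-backed DECIMAL(10, 2)); the path and the `spark` session are assumed:]

    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.DecimalType

    spark.range(3)
      .select(col("id").cast(DecimalType(5, 2)).as("d"))
      .write.mode("overwrite").parquet("/tmp/parquet-decimal-demo")

    // DECIMAL(10, 2) is not a 32-bit decimal, so isLazyDecodingSupported reports false and
    // the values are upcast through IntegerToLongUpdater instead.
    spark.read.schema("d DECIMAL(10, 2)").parquet("/tmp/parquet-decimal-demo").show()
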
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index b3be89085014e..dd72ce6b31961 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -29,7 +29,7 @@ import org.apache.parquet.io.ColumnIOFactory
 import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter}
 import org.apache.parquet.schema.{GroupType, Type, Types}
 import org.apache.parquet.schema.LogicalTypeAnnotation._
-import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, FIXED_LEN_BYTE_ARRAY, INT32, INT64, INT96}
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, FIXED_LEN_BYTE_ARRAY, FLOAT, INT32, INT64, INT96}

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
@@ -313,6 +313,16 @@ private[parquet] class ParquetRowConverter(
          override def addInt(value: Int): Unit =
            this.updater.setLong(Integer.toUnsignedLong(value))
        }
+      case LongType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 =>
+        new ParquetPrimitiveConverter(updater) {
+          override def addInt(value: Int): Unit =
+            this.updater.setLong(value)
+        }
+      case DoubleType if parquetType.asPrimitiveType().getPrimitiveTypeName == FLOAT =>
+        new ParquetPrimitiveConverter(updater) {
+          override def addFloat(value: Float): Unit =
+            this.updater.setDouble(value)
+        }
      case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType |
           _: AnsiIntervalType =>
        new ParquetPrimitiveConverter(updater)
@@ -438,6 +448,16 @@ private[parquet] class ParquetRowConverter(
          }
        }

+      // Allow upcasting INT32 date to timestampNTZ.
+      case TimestampNTZType if schemaConverter.isTimestampNTZEnabled() &&
+        parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 &&
+        parquetType.getLogicalTypeAnnotation.isInstanceOf[DateLogicalTypeAnnotation] =>
+        new ParquetPrimitiveConverter(updater) {
+          override def addInt(value: Int): Unit = {
+            this.updater.set(DateTimeUtils.daysToMicros(dateRebaseFunc(value), ZoneOffset.UTC))
+          }
+        }
+
      case DateType =>
        new ParquetPrimitiveConverter(updater) {
          override def addInt(value: Int): Unit = {
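[Reviewer note, not part of the patch: the ParquetRowConverter changes cover the non-vectorized, parquet-mr backed path, so the same widening also works with the vectorized reader turned off. A small sketch under the same assumptions as above (illustrative path, active `spark` session); spark.sql.parquet.enableVectorizedReader is the existing config key:]

    spark.sql("SELECT DATE'2020-01-01' AS d")
      .write.mode("overwrite").parquet("/tmp/parquet-date-demo")

    spark.conf.set("spark.sql.parquet.enableVectorizedReader", "false")
    // DATE (INT32) is read back as TIMESTAMP_NTZ through the new addInt override above.
    spark.read.schema("d TIMESTAMP_NTZ").parquet("/tmp/parquet-date-demo").show()
    spark.conf.set("spark.sql.parquet.enableVectorizedReader", "true")
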
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 374a8a8078edc..6319d47ffb78b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -1070,17 +1070,6 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
     }
   }

-  test("SPARK-35640: int as long should throw schema incompatible error") {
-    val data = (1 to 4).map(i => Tuple1(i))
-    val readSchema = StructType(Seq(StructField("_1", DataTypes.LongType)))
-
-    withParquetFile(data) { path =>
-      val errMsg = intercept[Exception](spark.read.schema(readSchema).parquet(path).collect())
-        .getMessage
-      assert(errMsg.contains("Parquet column cannot be converted in file"))
-    }
-  }
-
   test("write metadata") {
     val hadoopConf = spark.sessionState.newHadoopConf()
     withTempPath { file =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 43103db522bac..41019c83f7896 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -1098,19 +1098,13 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
   test("row group skipping doesn't overflow when reading into larger type") {
     withTempPath { path =>
       Seq(0).toDF("a").write.parquet(path.toString)
-      // The vectorized and non-vectorized readers will produce different exceptions, we don't need
-      // to test both as this covers row group skipping.
-      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
-        // Reading integer 'a' as a long isn't supported. Check that an exception is raised instead
-        // of incorrectly skipping the single row group and producing incorrect results.
-        val exception = intercept[SparkException] {
+      withAllParquetReaders {
+        val result =
           spark.read
             .schema("a LONG")
             .parquet(path.toString)
             .where(s"a < ${Long.MaxValue}")
-            .collect()
-        }
-        assert(exception.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException])
+        checkAnswer(result, Row(0))
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
new file mode 100644
index 0000000000000..e612ada038f38
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import java.io.File
+
+import org.apache.hadoop.fs.Path
+import org.apache.parquet.format.converter.ParquetMetadataConverter
+import org.apache.parquet.hadoop.{ParquetFileReader, ParquetOutputFormat}
+import org.apache.spark.SparkException
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+class ParquetTypeWideningSuite
+    extends QueryTest
+    with ParquetTest
+    with SharedSparkSession
+    with AdaptiveSparkPlanHelper {
+
+  import testImplicits._
+
+  /**
+   * Write a Parquet file with the given values stored using type `fromType` and read it back
+   * using type `toType` with each Parquet reader. If `expectError` returns true, check that an
+   * error is thrown during the read. Otherwise check that the data read matches the data written.
+   */
+  private def checkAllParquetReaders(
+      values: Seq[String],
+      fromType: DataType,
+      toType: DataType,
+      expectError: => Boolean): Unit = {
+    val timestampRebaseModes = toType match {
+      case _: TimestampNTZType | _: DateType => Seq("CORRECTED", "LEGACY")
+      case _ => Seq("CORRECTED")
+    }
+    for {
+      dictionaryEnabled <- Seq(true, false)
+      timestampRebaseMode <- timestampRebaseModes
+    }
+    withClue(
+      s"with dictionary encoding '$dictionaryEnabled' with timestamp rebase mode " +
+        s"'$timestampRebaseMode'") {
+      withAllParquetWriters {
+        withTempDir { dir =>
+          val expected =
+            writeParquetFiles(dir, values, fromType, dictionaryEnabled, timestampRebaseMode)
+          withAllParquetReaders {
+            if (expectError) {
+              val exception = intercept[SparkException] {
+                readParquetFiles(dir, toType).collect()
+              }
+              assert(
+                exception.getCause.getCause
+                  .isInstanceOf[SchemaColumnConvertNotSupportedException] ||
+                  exception.getCause.getCause
+                    .isInstanceOf[org.apache.parquet.io.ParquetDecodingException])
+            } else {
+              checkAnswer(readParquetFiles(dir, toType), expected.select($"a".cast(toType)))
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Reads all parquet files in the given directory using the given type.
+   */
+  private def readParquetFiles(dir: File, dataType: DataType): DataFrame = {
+    spark.read.schema(s"a ${dataType.sql}").parquet(dir.getAbsolutePath)
+  }
+
+  /**
+   * Writes values to a parquet file in the given directory using the given type and returns a
+   * DataFrame corresponding to the data written. If dictionaryEnabled is true, the columns will
+   * be dictionary encoded. Each provided value is repeated 10 times to allow dictionary encoding
+   * to be used. timestampRebaseMode can be either "CORRECTED" or "LEGACY", see
+   * [[SQLConf.PARQUET_REBASE_MODE_IN_WRITE]]
+   */
+  private def writeParquetFiles(
+      dir: File,
+      values: Seq[String],
+      dataType: DataType,
+      dictionaryEnabled: Boolean,
+      timestampRebaseMode: String = "CORRECTED"): DataFrame = {
+    val repeatedValues = List.fill(if (dictionaryEnabled) 10 else 1)(values).flatten
+    val df = repeatedValues.toDF("a").select(col("a").cast(dataType))
+    withSQLConf(
+      ParquetOutputFormat.ENABLE_DICTIONARY -> dictionaryEnabled.toString,
+      SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key -> timestampRebaseMode) {
+      df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    }
+
+    // Decimals stored as byte arrays (precision > 18) are not dictionary encoded.
+    if (dictionaryEnabled && !DecimalType.isByteArrayDecimalType(dataType)) {
+      assertAllParquetFilesDictionaryEncoded(dir)
+    }
+    df
+  }
+
+  /**
+   * Asserts that all parquet files in the given directory have all their columns dictionary
+   * encoded.
+   */
+  private def assertAllParquetFilesDictionaryEncoded(dir: File): Unit = {
+    dir.listFiles(_.getName.endsWith(".parquet")).foreach { file =>
+      val parquetMetadata = ParquetFileReader.readFooter(
+        spark.sessionState.newHadoopConf(),
+        new Path(dir.toString, file.getName),
+        ParquetMetadataConverter.NO_FILTER)
+      parquetMetadata.getBlocks.forEach { block =>
+        block.getColumns.forEach { col =>
+          assert(
+            col.hasDictionaryPage,
+            "This test covers dictionary encoding but column " +
+              s"'${col.getPath.toDotString}' in the test data is not dictionary encoded.")
+        }
+      }
+    }
+  }
+
+  for {
+    case (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
+      (Seq("1", "2", Short.MinValue.toString), ShortType, IntegerType),
+      // Int->Short isn't a widening conversion but Parquet stores both as INT32 so it just works.
+      (Seq("1", "2", Short.MinValue.toString), IntegerType, ShortType),
+      (Seq("1", "2", Int.MinValue.toString), IntegerType, LongType),
+      (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampNTZType),
+      (Seq("1.23", "10.34"), FloatType, DoubleType))
+  }
+  test(s"parquet widening conversion $fromType -> $toType") {
+    checkAllParquetReaders(values, fromType, toType, expectError = false)
+  }
+
+  for {
+    case (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
+      (Seq("1", "2", Int.MinValue.toString), LongType, IntegerType),
+      // Test different timestamp types
+      (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampNTZType, DateType),
+      (Seq("1.23", "10.34"), DoubleType, FloatType))
+  }
+  test(s"unsupported parquet conversion $fromType -> $toType") {
+    checkAllParquetReaders(values, fromType, toType, expectError = true)
+  }
+
+  for {
+    (fromPrecision, toPrecision) <-
+      // Test widening and narrowing precision between the same and different decimal physical
+      // parquet types:
+      // - INT32: precisions 5, 7
+      // - INT64: precisions 10, 12
+      // - FIXED_LEN_BYTE_ARRAY: precisions 20, 22
+      Seq(5 -> 7, 5 -> 10, 5 -> 20, 10 -> 12, 10 -> 20, 20 -> 22) ++
+        Seq(7 -> 5, 10 -> 5, 20 -> 5, 12 -> 10, 20 -> 10, 22 -> 20)
+  }
+  test(
+    s"parquet decimal precision change Decimal($fromPrecision, 2) -> Decimal($toPrecision, 2)") {
+    checkAllParquetReaders(
+      values = Seq("1.23", "10.34"),
+      fromType = DecimalType(fromPrecision, 2),
+      toType = DecimalType(toPrecision, 2),
+      expectError = fromPrecision > toPrecision &&
+        // parquet-mr allows reading decimals into a smaller precision decimal type without
+        // checking for overflows. See test below.
+        spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
+  }
+
+  test("parquet decimal type change Decimal(5, 2) -> Decimal(3, 2) overflows with parquet-mr") {
+    withTempDir { dir =>
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+        writeParquetFiles(
+          dir,
+          values = Seq("123.45", "999.99"),
+          DecimalType(5, 2),
+          dictionaryEnabled = false)
+        checkAnswer(readParquetFiles(dir, DecimalType(3, 2)), Row(null) :: Row(null) :: Nil)
+      }
+    }
+  }
+
+  test("parquet type change IntegerType -> ShortType overflows") {
+    withTempDir { dir =>
+      withAllParquetReaders {
+        // Int and Short are both stored as INT32 in Parquet, but a value larger than
+        // Short.MaxValue overflows when read back as Short in Spark.
+        val overflowValue = Short.MaxValue.toInt + 1
+        writeParquetFiles(
+          dir,
+          Seq(overflowValue.toString),
+          IntegerType,
+          dictionaryEnabled = false)
+        checkAnswer(readParquetFiles(dir, ShortType), Row(Short.MinValue))
+      }
+    }
+  }
+}
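[Reviewer note, not part of the patch: the new suite exercises both Parquet readers, both writer versions, and dictionary as well as plain encodings for every conversion listed above. It can be run on its own with the usual SBT test filter, for example:]

    build/sbt "sql/testOnly org.apache.spark.sql.execution.datasources.parquet.ParquetTypeWideningSuite"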