diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
index 729812252b102..8fce52051d7dc 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
@@ -89,6 +89,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
         return new ByteUpdater();
       } else if (sparkType == DataTypes.ShortType) {
         return new ShortUpdater();
+      } else if (sparkType == DataTypes.DoubleType) {
+        return new IntegerToDoubleUpdater();
       } else if (sparkType == DataTypes.DateType) {
         if ("CORRECTED".equals(datetimeRebaseMode)) {
           return new IntegerUpdater();
@@ -331,6 +333,41 @@ public void decodeSingleDictionaryId(
     }
   }
 
+  static class IntegerToDoubleUpdater implements ParquetVectorUpdater {
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; ++i) {
+        values.putDouble(offset + i, valuesReader.readInteger());
+      }
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipIntegers(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      values.putDouble(offset, valuesReader.readInteger());
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      values.putDouble(offset, dictionary.decodeToInt(dictionaryIds.getDictId(offset)));
+    }
+  }
+
   static class DateToTimestampNTZUpdater implements ParquetVectorUpdater {
     @Override
     public void readValues(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index dd72ce6b31961..b2222f4297e90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -318,6 +318,11 @@ private[parquet] class ParquetRowConverter(
           override def addInt(value: Int): Unit = this.updater.setLong(value)
         }
 
+      case DoubleType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 =>
+        new ParquetPrimitiveConverter(updater) {
+          override def addInt(value: Int): Unit =
+            this.updater.setDouble(value)
+        }
       case DoubleType if parquetType.asPrimitiveType().getPrimitiveTypeName == FLOAT =>
         new ParquetPrimitiveConverter(updater) {
           override def addFloat(value: Float): Unit =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
index b862c3385592b..811907e39c202 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
@@ -145,8 +145,11 @@ class ParquetTypeWideningSuite
       // Int->Short isn't a widening conversion but Parquet stores both as INT32 so it just works.
       (Seq("1", "2", Short.MinValue.toString), IntegerType, ShortType),
       (Seq("1", "2", Int.MinValue.toString), IntegerType, LongType),
-      (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampNTZType),
-      (Seq("1.23", "10.34"), FloatType, DoubleType))
+      (Seq("1", "2", Short.MinValue.toString), ShortType, DoubleType),
+      (Seq("1", "2", Int.MinValue.toString), IntegerType, DoubleType),
+      (Seq("1.23", "10.34"), FloatType, DoubleType),
+      (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampNTZType)
+    )
   }
   test(s"parquet widening conversion $fromType -> $toType") {
     checkAllParquetReaders(values, fromType, toType, expectError = false)
@@ -155,9 +158,10 @@
   for {
     (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
       (Seq("1", "2", Int.MinValue.toString), LongType, IntegerType),
-      // Test different timestamp types
-      (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampNTZType, DateType),
-      (Seq("1.23", "10.34"), DoubleType, FloatType))
+      (Seq("1.23", "10.34"), DoubleType, FloatType),
+      (Seq("1.23", "10.34"), FloatType, LongType),
+      (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampNTZType, DateType)
+    )
   }
   test(s"unsupported parquet conversion $fromType -> $toType") {
     checkAllParquetReaders(values, fromType, toType, expectError = true)