From c3cf357bb8c3ac7ac7052bc9749ee5f44e2dc5c4 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Mon, 21 Mar 2022 10:33:40 -0700 Subject: [PATCH] Add UDT support to ParquetCachedBatchSerializer (CPU) (#4955) * Add support for UDT Signed-off-by: Raza Jafri * add test and checks to fallback to CPU Signed-off-by: Raza Jafri * some refactoring to simplify code Signed-off-by: Raza Jafri * addressed review comments Signed-off-by: Raza Jafri * removed the mapping Signed-off-by: Raza Jafri Co-authored-by: Raza Jafri --- .../src/main/python/cache_test.py | 16 +- .../spark/ParquetCachedBatchSerializer.scala | 20 +- .../shims/ParquetCachedBatchSerializer.scala | 280 ++++++------------ .../spark/sql/rapids/PCBSSchemaHelper.scala | 136 +++++++++ .../com/nvidia/spark/rapids/TypeChecks.scala | 2 +- 5 files changed, 258 insertions(+), 196 deletions(-) create mode 100644 sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/PCBSSchemaHelper.scala diff --git a/integration_tests/src/main/python/cache_test.py b/integration_tests/src/main/python/cache_test.py index d718b619909..8602ca1c030 100644 --- a/integration_tests/src/main/python/cache_test.py +++ b/integration_tests/src/main/python/cache_test.py @@ -20,6 +20,8 @@ from spark_session import with_cpu_session, with_gpu_session, is_before_spark_330 from join_test import create_df from marks import incompat, allow_non_gpu, ignore_order +import pyspark.mllib.linalg as mllib +import pyspark.ml.linalg as ml enable_vectorized_confs = [{"spark.sql.inMemoryColumnarStorage.enableVectorizedReader": "true"}, {"spark.sql.inMemoryColumnarStorage.enableVectorizedReader": "false"}] @@ -286,6 +288,19 @@ def helper(spark): assert_gpu_and_cpu_are_equal_collect(helper) +def test_cache_udt(): + def fun(spark): + df = spark.sparkContext.parallelize([ + (mllib.DenseVector([1, ]), ml.DenseVector([1, ])), + (mllib.SparseVector(1, [0, ], [1, ]), ml.SparseVector(1, [0, ], [1, ])) + ]).toDF(["mllib_v", "ml_v"]) + df.cache().count() + return df.selectExpr("mllib_v", "ml_v").collect() + cpu_result = with_cpu_session(fun) + gpu_result = with_gpu_session(fun) + # assert_gpu_and_cpu_are_equal_collect method doesn't handle UDT so we just write a single + # statement here to compare + assert cpu_result == gpu_result, "not equal" @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Spark3.3.0') @pytest.mark.parametrize('enable_vectorized_conf', enable_vectorized_confs, ids=idfn) @@ -296,4 +311,3 @@ def test_func(spark): df.cache().count() return df.selectExpr("b", "a") assert_gpu_and_cpu_are_equal_collect(test_func, enable_vectorized_conf) - diff --git a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/ParquetCachedBatchSerializer.scala b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/ParquetCachedBatchSerializer.scala index b3ca7244694..962a083a1d9 100644 --- a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/ParquetCachedBatchSerializer.scala +++ b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/ParquetCachedBatchSerializer.scala @@ -16,7 +16,7 @@ package com.nvidia.spark -import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuExec, RapidsConf, RapidsMeta, ShimLoader, SparkPlanMeta} +import com.nvidia.spark.rapids.{DataFromReplacementRule, ExecChecks, GpuExec, RapidsConf, RapidsMeta, ShimLoader, SparkPlanMeta} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -25,7 +25,7 @@ import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer} import 
org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.rapids.shims.GpuInMemoryTableScanExec -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.StorageLevel @@ -46,6 +46,22 @@ class InMemoryTableScanMeta( extends SparkPlanMeta[InMemoryTableScanExec](imts, conf, parent, rule) { override def tagPlanForGpu(): Unit = { + def stringifyTypeAttributeMap(groupedByType: Map[DataType, Set[String]]): String = { + groupedByType.map { case (dataType, nameSet) => + dataType + " " + nameSet.mkString("[", ", ", "]") + }.mkString(", ") + } + + val supportedTypeSig = rule.getChecks.get.asInstanceOf[ExecChecks] + val unsupportedTypes: Map[DataType, Set[String]] = imts.relation.output + .filterNot(attr => supportedTypeSig.check.isSupportedByPlugin(attr.dataType)) + .groupBy(_.dataType) + .mapValues(_.map(_.name).toSet) + + val msgFormat = "unsupported data types in output: %s" + if (unsupportedTypes.nonEmpty) { + willNotWorkOnGpu(msgFormat.format(stringifyTypeAttributeMap(unsupportedTypes))) + } if (!imts.relation.cacheBuilder.serializer .isInstanceOf[com.nvidia.spark.ParquetCachedBatchSerializer]) { willNotWorkOnGpu("ParquetCachedBatchSerializer is not being used") diff --git a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/ParquetCachedBatchSerializer.scala b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/ParquetCachedBatchSerializer.scala index afcb0fe9792..2252c871009 100644 --- a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/ParquetCachedBatchSerializer.scala +++ b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/ParquetCachedBatchSerializer.scala @@ -19,7 +19,6 @@ package com.nvidia.spark.rapids.shims import java.io.{InputStream, IOException} import java.lang.reflect.Method import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConverters._ import scala.collection.mutable @@ -58,6 +57,7 @@ import org.apache.spark.sql.execution.datasources.parquet.rapids.shims.{ParquetR import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, WritableColumnVector} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.rapids.PCBSSchemaHelper import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.StorageLevel @@ -267,18 +267,8 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi override def supportsColumnarOutput(schema: StructType): Boolean = schema.fields.forall { f => // only check spark b/c if we are on the GPU then we will be calling the gpu method regardless - isTypeSupportedByColumnarSparkParquetWriter(f.dataType) || f.dataType == DataTypes.NullType - } - - private def isTypeSupportedByColumnarSparkParquetWriter(dataType: DataType): Boolean = { - // Columnar writer in Spark only supports AtomicTypes ATM - dataType match { - case TimestampType | StringType | BooleanType | DateType | BinaryType | - DoubleType | FloatType | ByteType | IntegerType | LongType | ShortType => true - case _: DecimalType => true - case other if GpuTypeShims.isParquetColumnarWriterSupportedForType(other) => true - case _ => false - } + 
PCBSSchemaHelper.isTypeSupportedByColumnarSparkParquetWriter(f.dataType) || + f.dataType == DataTypes.NullType } def isSchemaSupportedByCudf(schema: Seq[Attribute]): Boolean = { @@ -294,24 +284,6 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi } } - /** - * This method checks if the datatype passed is officially supported by parquet. - * - * Please refer to https://github.com/apache/parquet-format/blob/master/LogicalTypes.md to see - * the what types are supported by parquet - */ - def isTypeSupportedByParquet(dataType: DataType): Boolean = { - dataType match { - case CalendarIntervalType | NullType => false - case s: StructType => s.forall(field => isTypeSupportedByParquet(field.dataType)) - case ArrayType(elementType, _) => isTypeSupportedByParquet(elementType) - case MapType(keyType, valueType, _) => isTypeSupportedByParquet(keyType) && - isTypeSupportedByParquet(valueType) - case d: DecimalType if d.scale < 0 => false - case _ => true - } - } - /** * Convert an `RDD[ColumnarBatch]` into an `RDD[CachedBatch]` in preparation for caching the data. * This method uses Parquet Writer on the GPU to write the cached batch @@ -645,7 +617,8 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi dataType match { case s@StructType(_) => val listBuffer = new ListBuffer[InternalRow]() - val supportedSchema = mapping(dataType).asInstanceOf[StructType] + val supportedSchema = + PCBSSchemaHelper.getSupportedDataType(dataType).asInstanceOf[StructType] arrayData.foreach(supportedSchema, (_, data) => { val structRow = handleStruct(data.asInstanceOf[InternalRow], s, s) @@ -750,7 +723,7 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi withResource(ParquetFileReader.open(inputFile, options)) { parquetFileReader => val parquetSchema = parquetFileReader.getFooter.getFileMetaData.getSchema val hasUnsupportedType = origCacheSchema.exists { field => - !isTypeSupportedByParquet(field.dataType) + !PCBSSchemaHelper.isTypeSupportedByParquet(field.dataType) } val unsafeRows = new ArrayBuffer[InternalRow] @@ -808,55 +781,48 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi newRow: InternalRow): Unit = { schema.indices.foreach { index => val dataType = schema(index).dataType - if (mapping.contains(dataType) || dataType == CalendarIntervalType || - dataType == NullType || - (dataType.isInstanceOf[DecimalType] - && dataType.asInstanceOf[DecimalType].scale < 0)) { - if (row.isNullAt(index)) { - newRow.setNullAt(index) - } else { - dataType match { - case s@StructType(_) => - val supportedSchema = mapping(dataType) - .asInstanceOf[StructType] - val structRow = - handleStruct(row.getStruct(index, supportedSchema.size), s, s) - newRow.update(index, structRow) - - case a@ArrayType(_, _) => - val arrayData = row.getArray(index) - newRow.update(index, handleArray(a.elementType, arrayData)) - - case MapType(keyType, valueType, _) => - val mapData = row.getMap(index) - newRow.update(index, handleMap(keyType, valueType, mapData)) - - case CalendarIntervalType => - val interval = handleInterval(row, index) - if (interval == null) { - newRow.setNullAt(index) - } else { - newRow.setInterval(index, interval) - } - case d: DecimalType => - if (row.isNullAt(index)) { - newRow.setDecimal(index, null, d.precision) + if (row.isNullAt(index)) { + newRow.setNullAt(index) + } else { + dataType match { + case s@StructType(_) => + val supportedSchema = + 
PCBSSchemaHelper.getSupportedDataType(dataType).asInstanceOf[StructType] + val structRow = + handleStruct(row.getStruct(index, supportedSchema.size), s, s) + newRow.update(index, structRow) + + case a@ArrayType(_, _) => + val arrayData = row.getArray(index) + newRow.update(index, handleArray(a.elementType, arrayData)) + + case MapType(keyType, valueType, _) => + val mapData = row.getMap(index) + newRow.update(index, handleMap(keyType, valueType, mapData)) + + case CalendarIntervalType => + val interval = handleInterval(row, index) + if (interval == null) { + newRow.setNullAt(index) + } else { + newRow.setInterval(index, interval) + } + case d: DecimalType => + if (row.isNullAt(index)) { + newRow.setDecimal(index, null, d.precision) + } else { + val dec = if (d.precision <= Decimal.MAX_INT_DIGITS) { + Decimal(row.getInt(index).toLong, d.precision, d.scale) } else { - val dec = if (d.precision <= Decimal.MAX_INT_DIGITS) { - Decimal(row.getInt(index).toLong, d.precision, d.scale) - } else { - Decimal(row.getLong(index), d.precision, d.scale) - } - newRow.update(index, dec) + Decimal(row.getLong(index), d.precision, d.scale) } - case NullType => - newRow.setNullAt(index) - case _ => - newRow.update(index, row.get(index, dataType)) - } + newRow.update(index, dec) + } + case NullType => + newRow.setNullAt(index) + case _ => + newRow.update(index, row.get(index, dataType)) } - } else { - newRow.update(index, row.get(index, dataType)) } } } @@ -1073,11 +1039,6 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi conf } - private val intervalStructType = new StructType() - .add("_days", IntegerType) - .add("_months", IntegerType) - .add("_ms", LongType) - def getBytesAllowedPerBatch(conf: SQLConf): Long = { val gpuBatchSize = RapidsConf.GPU_BATCH_SIZE_BYTES.get(conf) // we are rough estimating 0.5% as meta_data_size. we can do better estimation in future @@ -1129,7 +1090,7 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi // is there a type that spark doesn't support by default in the schema? 
val hasUnsupportedType: Boolean = origCachedAttributes.exists { attribute => - !isTypeSupportedByParquet(attribute.dataType) + !PCBSSchemaHelper.isTypeSupportedByParquet(attribute.dataType) } def getIterator: Iterator[InternalRow] = { @@ -1171,53 +1132,47 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi newRow: InternalRow): Unit = { schema.indices.foreach { index => val dataType = schema(index).dataType - if (mapping.contains(dataType) || dataType == CalendarIntervalType || - dataType == NullType || - (dataType.isInstanceOf[DecimalType] - && dataType.asInstanceOf[DecimalType].scale < 0)) { - if (row.isNullAt(index)) { - newRow.setNullAt(index) - } else { - dataType match { - case s@StructType(_) => - val newSchema = mapping(dataType).asInstanceOf[StructType] - val structRow = - handleStruct(row.getStruct(index, s.fields.length), s, newSchema) - newRow.update(index, structRow) - - case ArrayType(arrayDataType, _) => - val arrayData = row.getArray(index) - val newArrayData = handleArray(arrayDataType, arrayData) - newRow.update(index, newArrayData) - - case MapType(keyType, valueType, _) => - val mapData = row.getMap(index) - val map = handleMap(keyType, valueType, mapData) - newRow.update(index, map) - - case CalendarIntervalType => - val structData: InternalRow = handleInterval(row, index) - if (structData == null) { - newRow.setNullAt(index) - } else { - newRow.update(index, structData) - } - - case d: DecimalType if d.scale < 0 => - if (d.precision <= Decimal.MAX_INT_DIGITS) { - newRow.update(index, row.getDecimal(index, d.precision, d.scale) - .toUnscaledLong.toInt) - } else { - newRow.update(index, row.getDecimal(index, d.precision, d.scale) - .toUnscaledLong) - } - - case _ => - newRow.update(index, row.get(index, dataType)) - } - } + if (row.isNullAt(index)) { + newRow.setNullAt(index) } else { - newRow.update(index, row.get(index, dataType)) + dataType match { + case s@StructType(_) => + val newSchema = + PCBSSchemaHelper.getSupportedDataType(dataType).asInstanceOf[StructType] + val structRow = + handleStruct(row.getStruct(index, s.fields.length), s, newSchema) + newRow.update(index, structRow) + + case ArrayType(arrayDataType, _) => + val arrayData = row.getArray(index) + val newArrayData = handleArray(arrayDataType, arrayData) + newRow.update(index, newArrayData) + + case MapType(keyType, valueType, _) => + val mapData = row.getMap(index) + val map = handleMap(keyType, valueType, mapData) + newRow.update(index, map) + + case CalendarIntervalType => + val structData: InternalRow = handleInterval(row, index) + if (structData == null) { + newRow.setNullAt(index) + } else { + newRow.update(index, structData) + } + + case d: DecimalType if d.scale < 0 => + if (d.precision <= Decimal.MAX_INT_DIGITS) { + newRow.update(index, row.getDecimal(index, d.precision, d.scale) + .toUnscaledLong.toInt) + } else { + newRow.update(index, row.getDecimal(index, d.precision, d.scale) + .toUnscaledLong) + } + + case _ => + newRow.update(index, row.get(index, dataType)) + } } } } @@ -1327,46 +1282,6 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi } - val mapping = new mutable.HashMap[DataType, DataType]() - - def getSupportedDataType(curId: AtomicLong, dataType: DataType): DataType = { - dataType match { - case CalendarIntervalType => - intervalStructType - case NullType => - ByteType - case s: StructType => - val newStructType = StructType( - s.indices.map { index => - StructField(curId.getAndIncrement().toString, - 
getSupportedDataType(curId, s.fields(index).dataType), s.fields(index).nullable, - s.fields(index).metadata) - }) - mapping.put(s, newStructType) - newStructType - case a@ArrayType(elementType, nullable) => - val newArrayType = - ArrayType(getSupportedDataType(curId, elementType), nullable) - mapping.put(a, newArrayType) - newArrayType - case m@MapType(keyType, valueType, nullable) => - val newKeyType = getSupportedDataType(curId, keyType) - val newValueType = getSupportedDataType(curId, valueType) - val mapType = MapType(newKeyType, newValueType, nullable) - mapping.put(m, mapType) - mapType - case d: DecimalType if d.scale < 0 => - val newType = if (d.precision <= Decimal.MAX_INT_DIGITS) { - IntegerType - } else { - LongType - } - newType - case _ => - dataType - } - } - // We want to change the original schema to have the new names as well private def sanitizeColumnNames(originalSchema: Seq[Attribute], schemaToCopyNamesFrom: Seq[Attribute]): Seq[Attribute] = { @@ -1379,27 +1294,8 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi cachedAttributes: Seq[Attribute], requestedAttributes: Seq[Attribute] = Seq.empty): (Seq[Attribute], Seq[Attribute]) = { - // We only handle CalendarIntervalType, Decimals and NullType ATM convert it to a supported type - val curId = new AtomicLong() - val newCachedAttributes = cachedAttributes.map { - attribute => val name = s"_col${curId.getAndIncrement()}" - attribute.dataType match { - case CalendarIntervalType => - AttributeReference(name, intervalStructType, - attribute.nullable, metadata = attribute.metadata)(attribute.exprId) - .asInstanceOf[Attribute] - case NullType => - AttributeReference(name, DataTypes.ByteType, - nullable = true, metadata = - attribute.metadata)(attribute.exprId).asInstanceOf[Attribute] - case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | DecimalType() => - AttributeReference(name, - getSupportedDataType(curId, attribute.dataType), - attribute.nullable, attribute.metadata)(attribute.exprId) - case _ => - attribute.withName(name) - } - } + val newCachedAttributes = + PCBSSchemaHelper.getSupportedSchemaFromUnsupported(cachedAttributes) val newRequestedAttributes = getSelectedSchemaFromCachedSchema(requestedAttributes, newCachedAttributes) diff --git a/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/PCBSSchemaHelper.scala b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/PCBSSchemaHelper.scala new file mode 100644 index 00000000000..a35dca1e13e --- /dev/null +++ b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/PCBSSchemaHelper.scala @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids + +import java.util.concurrent.atomic.AtomicLong + +import com.nvidia.spark.rapids.shims.GpuTypeShims + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.types.{ArrayType, AtomicType, ByteType, CalendarIntervalType, DataType, DataTypes, Decimal, DecimalType, IntegerType, LongType, MapType, NullType, StructField, StructType, UserDefinedType} + +object PCBSSchemaHelper { + val calendarIntervalStructType = new StructType() + .add("_days", IntegerType) + .add("_months", IntegerType) + .add("_ms", LongType) + + /** + * This method checks if the datatype passed is officially supported by Parquet. + * + * Please refer to https://github.com/apache/parquet-format/blob/master/LogicalTypes.md to see + * what types are supported by Parquet. + */ + def isTypeSupportedByParquet(dataType: DataType): Boolean = { + dataType match { + case CalendarIntervalType | NullType => false + case s: StructType => s.forall(field => isTypeSupportedByParquet(field.dataType)) + case ArrayType(elementType, _) => isTypeSupportedByParquet(elementType) + case MapType(keyType, valueType, _) => isTypeSupportedByParquet(keyType) && + isTypeSupportedByParquet(valueType) + case d: DecimalType if d.scale < 0 => false + // Atomic types + case _: AtomicType => true + case _ => false + } + } + + def isTypeSupportedByColumnarSparkParquetWriter(dataType: DataType): Boolean = { + // Columnar writer in Spark only supports AtomicTypes ATM + dataType match { + case _: AtomicType => true + case other if GpuTypeShims.isParquetColumnarWriterSupportedForType(other) => true + case _ => false + } + } + + /** + * This method converts types that Parquet doesn't recognize to types that Parquet understands, + * e.g. CalendarIntervalType is converted to a struct of two integer fields and a + * long field. + */ + def getSupportedDataType(dataType: DataType): DataType = { + dataType match { + case CalendarIntervalType => + calendarIntervalStructType + case NullType => + ByteType + case s: StructType => + val newStructType = StructType( + s.indices.map { index => + StructField(s.fields(index).name, + getSupportedDataType(s.fields(index).dataType), + s.fields(index).nullable, s.fields(index).metadata) + }) + newStructType + case _@ArrayType(elementType, nullable) => + val newArrayType = + ArrayType(getSupportedDataType(elementType), nullable) + newArrayType + case _@MapType(keyType, valueType, nullable) => + val newKeyType = getSupportedDataType(keyType) + val newValueType = getSupportedDataType(valueType) + val mapType = MapType(newKeyType, newValueType, nullable) + mapType + case d: DecimalType if d.scale < 0 => + val newType = if (d.precision <= Decimal.MAX_INT_DIGITS) { + IntegerType + } else { + LongType + } + newType + case _: AtomicType => dataType + case o => + throw new IllegalArgumentException(s"We don't support ${o.typeName}") + } + } + + /** + * There are certain types that are not supported by Parquet. This method converts the schema + * of those types to something Parquet understands, e.g.
a CalendarIntervalType attribute will be converted + * to an attribute with {@link calendarIntervalStructType} as its type. + */ + def getSupportedSchemaFromUnsupported(cachedAttributes: Seq[Attribute]): Seq[Attribute] = { + // We currently convert CalendarIntervalType, UDTs and NullType to supported types + val curId = new AtomicLong() + cachedAttributes.map { + attribute => + val name = s"_col${curId.getAndIncrement()}" + attribute.dataType match { + case CalendarIntervalType => + AttributeReference(name, calendarIntervalStructType, + attribute.nullable, metadata = attribute.metadata)(attribute.exprId) + .asInstanceOf[Attribute] + case NullType => + AttributeReference(name, DataTypes.ByteType, + nullable = true, metadata = + attribute.metadata)(attribute.exprId).asInstanceOf[Attribute] + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | DecimalType() => + AttributeReference(name, + getSupportedDataType(attribute.dataType), + attribute.nullable, attribute.metadata)(attribute.exprId) + case udt: UserDefinedType[_] => + AttributeReference(name, + getSupportedDataType(udt.sqlType), + attribute.nullable, attribute.metadata)(attribute.exprId) + case _ => + attribute.withName(name) + } + } + } + +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 077ec30db69..37ddd8b910d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -914,7 +914,7 @@ object FileFormatChecks { * The namedChecks map can be used to provide checks for specific groups of expressions. */ class ExecChecks private( - check: TypeSig, + val check: TypeSig, sparkSig: TypeSig, val namedChecks: Map[String, InputCheck], override val shown: Boolean = true)
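
For reviewers who want to exercise the new UDT cache path end to end outside of cache_test.py, the sketch below is a minimal, illustrative PySpark session. It assumes the RAPIDS Accelerator jar is already on the driver and executor classpath; the spark.plugins and spark.sql.cache.serializer settings are the usual way to enable the plugin and this serializer, and the data and column names are made up for the example.

from pyspark.sql import SparkSession
from pyspark.ml.linalg import DenseVector

# spark.sql.cache.serializer is a static conf, so it must be set before the
# session is created; the class below is the serializer this patch extends.
spark = (SparkSession.builder
         .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
         .config("spark.sql.cache.serializer",
                 "com.nvidia.spark.ParquetCachedBatchSerializer")
         .getOrCreate())

# pyspark.ml.linalg vectors are backed by VectorUDT, a UserDefinedType, which is
# the kind of column this change lets the serializer cache (via udt.sqlType).
df = spark.createDataFrame(
    [(1, DenseVector([1.0, 2.0])), (2, DenseVector([3.0, 4.0]))],
    ["id", "features"])
df.cache().count()                      # materialize the cached batches
print(df.selectExpr("id", "features").collect())

As in test_cache_udt above, the collected rows should match a plain CPU session; when the UDT is not supported on the GPU, the new ExecChecks-based tagging makes the in-memory table scan fall back to the CPU instead of failing.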