Add UDT support to ParquetCachedBatchSerializer (CPU) #4955

Merged · 6 commits · Mar 21, 2022
Changes from 5 commits
16 changes: 15 additions & 1 deletion integration_tests/src/main/python/cache_test.py
@@ -20,6 +20,8 @@
from spark_session import with_cpu_session, with_gpu_session, is_before_spark_330
from join_test import create_df
from marks import incompat, allow_non_gpu, ignore_order
import pyspark.mllib.linalg as mllib
import pyspark.ml.linalg as ml

enable_vectorized_confs = [{"spark.sql.inMemoryColumnarStorage.enableVectorizedReader": "true"},
{"spark.sql.inMemoryColumnarStorage.enableVectorizedReader": "false"}]
@@ -286,6 +288,19 @@ def helper(spark):

assert_gpu_and_cpu_are_equal_collect(helper)

def test_cache_udt():
    def fun(spark):
        df = spark.sparkContext.parallelize([
            (mllib.DenseVector([1, ]), ml.DenseVector([1, ])),
            (mllib.SparseVector(1, [0, ], [1, ]), ml.SparseVector(1, [0, ], [1, ]))
        ]).toDF(["mllib_v", "ml_v"])
        df.cache().count()
        return df.selectExpr("mllib_v", "ml_v").collect()
    cpu_result = with_cpu_session(fun)
    gpu_result = with_gpu_session(fun)
    # assert_gpu_and_cpu_are_equal_collect doesn't handle UDTs, so collect on both
    # sessions and compare the results directly
    assert cpu_result == gpu_result, "not equal"
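
# A minimal, hypothetical sketch (not part of this change): if the UDT rows ever
# need to be compared as plain Python values rather than vector objects, they
# could be normalized first. The helper name below is illustrative only.
def normalize_udt_row(row):
    # DenseVector and SparseVector both expose toArray(), which returns a numpy array
    return [v.toArray().tolist() if hasattr(v, 'toArray') else v for v in row]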

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Spark3.3.0')
@pytest.mark.parametrize('enable_vectorized_conf', enable_vectorized_confs, ids=idfn)
@@ -296,4 +311,3 @@ def test_func(spark):
        df.cache().count()
        return df.selectExpr("b", "a")
    assert_gpu_and_cpu_are_equal_collect(test_func, enable_vectorized_conf)

@@ -16,7 +16,7 @@

package com.nvidia.spark

import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuExec, RapidsConf, RapidsMeta, ShimLoader, SparkPlanMeta}
import com.nvidia.spark.rapids.{DataFromReplacementRule, ExecChecks, GpuExec, RapidsConf, RapidsMeta, ShimLoader, SparkPlanMeta}

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -25,7 +25,7 @@ import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.rapids.shims.GpuInMemoryTableScanExec
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.storage.StorageLevel

@@ -46,6 +46,22 @@ class InMemoryTableScanMeta(
extends SparkPlanMeta[InMemoryTableScanExec](imts, conf, parent, rule) {

override def tagPlanForGpu(): Unit = {
def stringifyTypeAttributeMap(groupedByType: Map[DataType, Set[String]]): String = {
groupedByType.map { case (dataType, nameSet) =>
dataType + " " + nameSet.mkString("[", ", ", "]")
}.mkString(", ")
}

val supportedTypeSig = rule.getChecks.get.asInstanceOf[ExecChecks]
val unsupportedTypes: Map[DataType, Set[String]] = imts.relation.output
.filterNot(attr => supportedTypeSig.check.isSupportedByPlugin(attr.dataType))
.groupBy(_.dataType)
.mapValues(_.map(_.name).toSet)

val msgFormat = "unsupported data types in output: %s"
if (unsupportedTypes.nonEmpty) {
willNotWorkOnGpu(msgFormat.format(stringifyTypeAttributeMap(unsupportedTypes)))
}
if (!imts.relation.cacheBuilder.serializer
.isInstanceOf[com.nvidia.spark.ParquetCachedBatchSerializer]) {
willNotWorkOnGpu("ParquetCachedBatchSerializer is not being used")
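The new tagPlanForGpu check above groups the scan's output attributes by each data type the plugin cannot handle and reports them in a single message. A self-contained sketch of that grouping and formatting, using illustrative types and column names (not taken from this PR):

import org.apache.spark.sql.types.{CalendarIntervalType, DataType, NullType}

object UnsupportedTypeMessageSketch {
  // Mirrors stringifyTypeAttributeMap above: render each type with the columns that use it
  def stringify(groupedByType: Map[DataType, Set[String]]): String =
    groupedByType.map { case (dataType, names) =>
      dataType + " " + names.mkString("[", ", ", "]")
    }.mkString(", ")

  def main(args: Array[String]): Unit = {
    val grouped: Map[DataType, Set[String]] = Map(
      CalendarIntervalType -> Set("interval_col"),
      NullType -> Set("null_col"))
    // Prints something like:
    // unsupported data types in output: CalendarIntervalType [interval_col], NullType [null_col]
    println("unsupported data types in output: " + stringify(grouped))
  }
}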
@@ -19,7 +19,6 @@ package com.nvidia.spark.rapids.shims
import java.io.{InputStream, IOException}
import java.lang.reflect.Method
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicLong

import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -58,6 +57,7 @@ import org.apache.spark.sql.execution.datasources.parquet.rapids.shims.{ParquetR
import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, WritableColumnVector}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.rapids.PCBSSchemaHelper
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.storage.StorageLevel
@@ -267,18 +267,8 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi

override def supportsColumnarOutput(schema: StructType): Boolean = schema.fields.forall { f =>
// only check spark b/c if we are on the GPU then we will be calling the gpu method regardless
isTypeSupportedByColumnarSparkParquetWriter(f.dataType) || f.dataType == DataTypes.NullType
}

private def isTypeSupportedByColumnarSparkParquetWriter(dataType: DataType): Boolean = {
// Columnar writer in Spark only supports AtomicTypes ATM
dataType match {
case TimestampType | StringType | BooleanType | DateType | BinaryType |
DoubleType | FloatType | ByteType | IntegerType | LongType | ShortType => true
case _: DecimalType => true
case other if GpuTypeShims.isParquetColumnarWriterSupportedForType(other) => true
case _ => false
}
PCBSSchemaHelper.isTypeSupportedByColumnarSparkParquetWriter(f.dataType) ||
f.dataType == DataTypes.NullType
}

def isSchemaSupportedByCudf(schema: Seq[Attribute]): Boolean = {
@@ -294,24 +284,6 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
}
}

/**
* This method checks if the datatype passed is officially supported by parquet.
*
* Please refer to https://github.com/apache/parquet-format/blob/master/LogicalTypes.md to see
* the what types are supported by parquet
*/
def isTypeSupportedByParquet(dataType: DataType): Boolean = {
dataType match {
case CalendarIntervalType | NullType => false
case s: StructType => s.forall(field => isTypeSupportedByParquet(field.dataType))
case ArrayType(elementType, _) => isTypeSupportedByParquet(elementType)
case MapType(keyType, valueType, _) => isTypeSupportedByParquet(keyType) &&
isTypeSupportedByParquet(valueType)
case d: DecimalType if d.scale < 0 => false
case _ => true
}
}

/**
* Convert an `RDD[ColumnarBatch]` into an `RDD[CachedBatch]` in preparation for caching the data.
* This method uses Parquet Writer on the GPU to write the cached batch
@@ -645,7 +617,8 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
dataType match {
case s@StructType(_) =>
val listBuffer = new ListBuffer[InternalRow]()
val supportedSchema = mapping(dataType).asInstanceOf[StructType]
val supportedSchema =
PCBSSchemaHelper.getOriginalDataType(dataType).asInstanceOf[StructType]
arrayData.foreach(supportedSchema, (_, data) => {
val structRow =
handleStruct(data.asInstanceOf[InternalRow], s, s)
@@ -750,7 +723,7 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
withResource(ParquetFileReader.open(inputFile, options)) { parquetFileReader =>
val parquetSchema = parquetFileReader.getFooter.getFileMetaData.getSchema
val hasUnsupportedType = origCacheSchema.exists { field =>
!isTypeSupportedByParquet(field.dataType)
!PCBSSchemaHelper.isTypeSupportedByParquet(field.dataType)
}

val unsafeRows = new ArrayBuffer[InternalRow]
@@ -808,16 +781,16 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
newRow: InternalRow): Unit = {
schema.indices.foreach { index =>
val dataType = schema(index).dataType
if (mapping.contains(dataType) || dataType == CalendarIntervalType ||
dataType == NullType ||
if (PCBSSchemaHelper.wasOriginalTypeConverted(dataType) ||
dataType == CalendarIntervalType || dataType == NullType ||
(dataType.isInstanceOf[DecimalType]
&& dataType.asInstanceOf[DecimalType].scale < 0)) {
if (row.isNullAt(index)) {
newRow.setNullAt(index)
} else {
dataType match {
case s@StructType(_) =>
val supportedSchema = mapping(dataType)
val supportedSchema = PCBSSchemaHelper.getOriginalDataType(dataType)
.asInstanceOf[StructType]
val structRow =
handleStruct(row.getStruct(index, supportedSchema.size), s, s)
@@ -1073,11 +1046,6 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
conf
}

private val intervalStructType = new StructType()
.add("_days", IntegerType)
.add("_months", IntegerType)
.add("_ms", LongType)

def getBytesAllowedPerBatch(conf: SQLConf): Long = {
val gpuBatchSize = RapidsConf.GPU_BATCH_SIZE_BYTES.get(conf)
// we are rough estimating 0.5% as meta_data_size. we can do better estimation in future
@@ -1129,7 +1097,7 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi

// is there a type that spark doesn't support by default in the schema?
val hasUnsupportedType: Boolean = origCachedAttributes.exists { attribute =>
!isTypeSupportedByParquet(attribute.dataType)
!PCBSSchemaHelper.isTypeSupportedByParquet(attribute.dataType)
}

def getIterator: Iterator[InternalRow] = {
@@ -1171,16 +1139,17 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
newRow: InternalRow): Unit = {
schema.indices.foreach { index =>
val dataType = schema(index).dataType
if (mapping.contains(dataType) || dataType == CalendarIntervalType ||
dataType == NullType ||
if (PCBSSchemaHelper.wasOriginalTypeConverted(dataType) ||
dataType == CalendarIntervalType || dataType == NullType ||
(dataType.isInstanceOf[DecimalType]
&& dataType.asInstanceOf[DecimalType].scale < 0)) {
if (row.isNullAt(index)) {
newRow.setNullAt(index)
} else {
dataType match {
case s@StructType(_) =>
val newSchema = mapping(dataType).asInstanceOf[StructType]
val newSchema = PCBSSchemaHelper.getOriginalDataType(dataType)
.asInstanceOf[StructType]
val structRow =
handleStruct(row.getStruct(index, s.fields.length), s, newSchema)
newRow.update(index, structRow)
@@ -1327,46 +1296,6 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi

}

val mapping = new mutable.HashMap[DataType, DataType]()

def getSupportedDataType(curId: AtomicLong, dataType: DataType): DataType = {
dataType match {
case CalendarIntervalType =>
intervalStructType
case NullType =>
ByteType
case s: StructType =>
val newStructType = StructType(
s.indices.map { index =>
StructField(curId.getAndIncrement().toString,
getSupportedDataType(curId, s.fields(index).dataType), s.fields(index).nullable,
s.fields(index).metadata)
})
mapping.put(s, newStructType)
newStructType
case a@ArrayType(elementType, nullable) =>
val newArrayType =
ArrayType(getSupportedDataType(curId, elementType), nullable)
mapping.put(a, newArrayType)
newArrayType
case m@MapType(keyType, valueType, nullable) =>
val newKeyType = getSupportedDataType(curId, keyType)
val newValueType = getSupportedDataType(curId, valueType)
val mapType = MapType(newKeyType, newValueType, nullable)
mapping.put(m, mapType)
mapType
case d: DecimalType if d.scale < 0 =>
val newType = if (d.precision <= Decimal.MAX_INT_DIGITS) {
IntegerType
} else {
LongType
}
newType
case _ =>
dataType
}
}

// We want to change the original schema to have the new names as well
private def sanitizeColumnNames(originalSchema: Seq[Attribute],
schemaToCopyNamesFrom: Seq[Attribute]): Seq[Attribute] = {
@@ -1379,27 +1308,8 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
cachedAttributes: Seq[Attribute],
requestedAttributes: Seq[Attribute] = Seq.empty): (Seq[Attribute], Seq[Attribute]) = {

// We only handle CalendarIntervalType, Decimals and NullType ATM convert it to a supported type
val curId = new AtomicLong()
val newCachedAttributes = cachedAttributes.map {
attribute => val name = s"_col${curId.getAndIncrement()}"
attribute.dataType match {
case CalendarIntervalType =>
AttributeReference(name, intervalStructType,
attribute.nullable, metadata = attribute.metadata)(attribute.exprId)
.asInstanceOf[Attribute]
case NullType =>
AttributeReference(name, DataTypes.ByteType,
nullable = true, metadata =
attribute.metadata)(attribute.exprId).asInstanceOf[Attribute]
case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | DecimalType() =>
AttributeReference(name,
getSupportedDataType(curId, attribute.dataType),
attribute.nullable, attribute.metadata)(attribute.exprId)
case _ =>
attribute.withName(name)
}
}
val newCachedAttributes =
PCBSSchemaHelper.getSupportedSchemaFromUnsupported(cachedAttributes)

val newRequestedAttributes =
getSelectedSchemaFromCachedSchema(requestedAttributes, newCachedAttributes)
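Throughout this diff, the type-conversion bookkeeping that previously lived inside ParquetCachedBatchSerializer (the mapping HashMap, getSupportedDataType, isTypeSupportedByParquet, the interval struct, and the attribute renaming) is delegated to org.apache.spark.sql.rapids.PCBSSchemaHelper, along with getSupportedSchemaFromUnsupported and isTypeSupportedByColumnarSparkParquetWriter, which are only called here. The helper itself is not part of this diff view; the sketch below reconstructs its assumed surface from the call sites and the removed code above, so the method bodies are illustrative stand-ins rather than the actual implementation.

import org.apache.spark.sql.types._

object PCBSSchemaHelperSketch {
  // Same logic as the removed isTypeSupportedByParquet: parquet has no logical
  // type for calendar intervals, nulls, or negative-scale decimals
  def isTypeSupportedByParquet(dataType: DataType): Boolean = dataType match {
    case CalendarIntervalType | NullType => false
    case s: StructType => s.forall(field => isTypeSupportedByParquet(field.dataType))
    case ArrayType(elementType, _) => isTypeSupportedByParquet(elementType)
    case MapType(keyType, valueType, _) =>
      isTypeSupportedByParquet(keyType) && isTypeSupportedByParquet(valueType)
    case d: DecimalType if d.scale < 0 => false
    case _ => true
  }

  // Assumed replacement for the removed `mapping` HashMap: tracks which original
  // types were rewritten to a parquet-friendly shape when the batch was cached
  private val converted = new scala.collection.mutable.HashMap[DataType, DataType]()
  def wasOriginalTypeConverted(dataType: DataType): Boolean = converted.contains(dataType)
  def getOriginalDataType(dataType: DataType): DataType = converted(dataType)
}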