-
Notifications
You must be signed in to change notification settings - Fork 230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEA] support json to struct function #8174
Changes from 8 commits
ab4fd03
c91be11
1488f78
4abc067
0fd5b0b
4050d91
a1ee3d1
de2d793
05e5ba4
906b36b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,12 +16,18 @@ | |
|
||
package org.apache.spark.sql.rapids | ||
|
||
import scala.collection.mutable.{ArrayBuffer, Set} | ||
|
||
import ai.rapids.cudf | ||
import com.nvidia.spark.rapids.{GpuColumnVector, GpuUnaryExpression} | ||
import com.nvidia.spark.rapids.{GpuColumnVector, GpuScalar, GpuUnaryExpression} | ||
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} | ||
import com.nvidia.spark.rapids.GpuCast.doCast | ||
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq | ||
import com.nvidia.spark.rapids.jni.MapUtils | ||
|
||
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, NullIntolerant, TimeZoneAwareExpression} | ||
import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} | ||
// import org.apache.spark.sql.types.{AbstractDataType, DataType, MapType, StringType, StructType} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: delete commented out code. |
||
import org.apache.spark.sql.types._ | ||
|
||
case class GpuJsonToStructs( | ||
schema: DataType, | ||
|
@@ -30,8 +36,156 @@ case class GpuJsonToStructs( | |
timeZoneId: Option[String] = None) | ||
extends GpuUnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes | ||
with NullIntolerant { | ||
|
||
private def cleanAndConcat(input: cudf.ColumnVector): (cudf.ColumnVector, cudf.ColumnVector) ={ | ||
withResource(cudf.Scalar.fromString("{}")) { emptyRow => | ||
val stripped = withResource(cudf.Scalar.fromString(" ")) { space => | ||
input.strip(space) | ||
} | ||
withResource(stripped) { stripped => | ||
val isNullOrEmptyInput = withResource(input.isNull) { isNull => | ||
val isEmpty = withResource(stripped.getCharLengths) { lengths => | ||
withResource(cudf.Scalar.fromInt(0)) { zero => | ||
lengths.lessOrEqualTo(zero) | ||
} | ||
} | ||
withResource(isEmpty) { isEmpty => | ||
isNull.binaryOp(cudf.BinaryOp.NULL_LOGICAL_OR, isEmpty, cudf.DType.BOOL8) | ||
} | ||
} | ||
closeOnExcept(isNullOrEmptyInput) { _ => | ||
withResource(isNullOrEmptyInput.ifElse(emptyRow, stripped)) { cleaned => | ||
revans2 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
withResource(cudf.Scalar.fromString("\n")) { lineSep => | ||
gerashegalov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
withResource(cudf.Scalar.fromString("\r")) { returnSep => | ||
withResource(cleaned.stringContains(lineSep)) { inputHas => | ||
withResource(inputHas.any()) { anyLineSep => | ||
if (anyLineSep.isValid && anyLineSep.getBoolean) { | ||
throw new IllegalArgumentException("We cannot currently support parsing " + | ||
"JSON that contains a line separator in it") | ||
} | ||
} | ||
} | ||
withResource(cleaned.stringContains(returnSep)) { inputHas => | ||
withResource(inputHas.any()) { anyReturnSep => | ||
if (anyReturnSep.isValid && anyReturnSep.getBoolean) { | ||
throw new IllegalArgumentException("We cannot currently support parsing " + | ||
"JSON that contains a carriage return in it") | ||
} | ||
} | ||
} | ||
} | ||
gerashegalov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
(isNullOrEmptyInput, cleaned.joinStrings(lineSep, emptyRow)) | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Process a sequence of field names. If there are duplicated field names, we only keep the field | ||
// name with the largest index in the sequence, for others, replace the field names with null. | ||
// Example: | ||
// Input = [("a", StringType), ("b", StringType), ("a", IntegerType)] | ||
// Output = [(null, StringType), ("b", StringType), ("a", IntegerType)] | ||
private def processFieldNames(names: Seq[(String, DataType)]): Seq[(String, DataType)] = { | ||
val existingNames = Set[String]() | ||
names.foldRight(Seq[(String, DataType)]())((elem, acc) => { | ||
val (name, dtype) = elem | ||
if (existingNames(name)) (null, dtype)+:acc else {existingNames += name; (name, dtype)+:acc}}) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: could we make the formatting less dense so it is simpler to read?
gerashegalov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
private def getSparkType(col: cudf.ColumnView): DataType = { | ||
col.getType match { | ||
case cudf.DType.INT8 | cudf.DType.UINT8 => ByteType | ||
case cudf.DType.INT16 | cudf.DType.UINT16 => ShortType | ||
case cudf.DType.INT32 | cudf.DType.UINT32 => IntegerType | ||
case cudf.DType.INT64 | cudf.DType.UINT64 => LongType | ||
case cudf.DType.FLOAT32 => FloatType | ||
case cudf.DType.FLOAT64 => DoubleType | ||
case cudf.DType.BOOL8 => BooleanType | ||
case cudf.DType.STRING => StringType | ||
case cudf.DType.LIST => ArrayType(getSparkType(col.getChildColumnView(0))) | ||
case cudf.DType.STRUCT => | ||
val structFields = ArrayBuffer.empty[StructField] | ||
(0 until col.getNumChildren).foreach { i => | ||
val child = col.getChildColumnView(i) | ||
structFields += StructField("", getSparkType(child)) | ||
} | ||
gerashegalov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
StructType(structFields) | ||
case t => throw new IllegalArgumentException( | ||
s"GpuJsonToStructs currently cannot process CUDF column of type $t.") | ||
} | ||
} | ||
|
||
override protected def doColumnar(input: GpuColumnVector): cudf.ColumnVector = { | ||
MapUtils.extractRawMapFromJsonString(input.getBase) | ||
schema match { | ||
case _: MapType => | ||
MapUtils.extractRawMapFromJsonString(input.getBase) | ||
case struct: StructType => { | ||
// We cannot handle all corner cases with this right now. The parser just isn't | ||
// good enough, but we will try to handle a few common ones. | ||
val numRows = input.getRowCount.toInt | ||
|
||
// Step 1: verify and preprocess the data to clean it up and normalize a few things | ||
// Step 2: Concat the data into a single buffer | ||
val (isNullOrEmpty, combined) = cleanAndConcat(input.getBase) | ||
withResource(isNullOrEmpty) { isNullOrEmpty => | ||
// Step 3: copy the data back to the host so we can parse it. | ||
val combinedHost = withResource(combined) { combined => | ||
combined.copyToHost() | ||
} | ||
// Step 4: Have cudf parse the JSON data | ||
val (names, rawTable) = withResource(combinedHost) { combinedHost => | ||
val data = combinedHost.getData | ||
val start = combinedHost.getStartListOffset(0) | ||
val end = combinedHost.getEndListOffset(0) | ||
val length = end - start | ||
|
||
withResource(cudf.Table.readJSON(cudf.JSONOptions.DEFAULT, data, start, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is having CUDF do name and type inference. Is that really what we want? Should we do this like we do for regular JSON parsing? (Never mind — it turns out we do the same thing in the JSON reader??? Why are we doing that? It is a huge waste of memory. Can we please file a follow-on issue to fix it both here and in the JSON reader? Bonus points if we can combine the reader code together. |
||
length)) { tableWithMeta => | ||
val names = tableWithMeta.getColumnNames | ||
(names, tableWithMeta.releaseTable()) | ||
} | ||
} | ||
|
||
// process duplicated field names in input struct schema | ||
val fieldNames = processFieldNames(struct.fields.map { field => | ||
(field.name, field.dataType)}) | ||
|
||
withResource(rawTable) { rawTable => | ||
// Step 5: verify that the data looks correct | ||
if (rawTable.getRowCount != numRows) { | ||
throw new IllegalStateException("The input data didn't parse correctly and we read " + | ||
s"a different number of rows than was expected. Expected $numRows, " + | ||
s"but got ${rawTable.getRowCount}") | ||
} | ||
|
||
// Step 6: get the data based on input struct schema | ||
val columns = fieldNames.safeMap { case (name, dtype) => | ||
val i = names.indexOf(name) | ||
if (i == -1) { | ||
GpuColumnVector.columnVectorFromNull(numRows, dtype) | ||
} else { | ||
val col = rawTable.getColumn(i) | ||
doCast(col, getSparkType(col), dtype, false, false, false) | ||
revans2 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
|
||
revans2 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Step 7: turn the data into a Struct | ||
withResource(columns) { columns => | ||
withResource(cudf.ColumnVector.makeStruct(columns: _*)) { structData => | ||
// Step 8: put nulls back in for nulls and empty strings | ||
withResource(GpuScalar.from(null, struct)) { nullVal => | ||
isNullOrEmpty.ifElse(nullVal, structData) | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
case _ => throw new IllegalArgumentException( | ||
s"GpuJsonToStructs currently does not support schema of type $schema.") | ||
} | ||
} | ||
|
||
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I only see tests for String, Long, and Struct. If we say that we support the other types, we really should have tests for them. This needs to include things like STRUCTs of STRUCTs and STRUCTs of LISTs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we say that we support all of the types in our meta object, then we need tests for all of the data types that JSON supports in Spark
https://github.com/apache/spark/blob/4a238cd9d8e80eed06732fc52b1456cb5ece6652/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala#L193-L385
I personally would rather see us start with a few simple types and add more as we add tests for them. So if we have tests for String, Int, array and struct, then we should only say that we support those types. We can add in support for boolean, byte, short, long, decimal (which needs to include multiple precision and scale types), Float, Double, Timestamp, TimestampNTZ, Date, Binary, CalendarInterval, YearMonthInterval, DayTimeInterval, UDT and NullTypes when a customer/management asks for them or when we have tests that show that they are working correctly.