Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new getJsonObject kernel for json_tuple #10635

Merged
merged 9 commits into from
Apr 25, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

package com.nvidia.spark.rapids

import ai.rapids.cudf.{GetJsonObjectOptions,Scalar}
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitSpillableInHalfByRows, withRetry}
import com.nvidia.spark.rapids.jni.JSONUtils
import com.nvidia.spark.rapids.shims.ShimExpression

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
Expand Down Expand Up @@ -59,26 +59,28 @@ case class GpuJsonTuple(children: Seq[Expression]) extends GpuGenerator
val json = inputBatch.column(generatorOffset).asInstanceOf[GpuColumnVector].getBase
val schema = Array.fill[DataType](fieldExpressions.length)(StringType)

val fieldScalars = fieldExpressions.safeMap { field =>
val fieldInstructions = fieldExpressions.map { field =>
withResourceIfAllowed(field.columnarEvalAny(inputBatch)) {
case fieldScalar: GpuScalar =>
// Specials characters like '.', '[', ']' are not supported in field names
Scalar.fromString("$." + fieldScalar.getBase.getJavaString)
val fieldString = fieldScalar.getBase.getJavaString
val key = new JSONUtils.PathInstructionJni(
JSONUtils.PathInstructionType.KEY, "", -1)
val named = new JSONUtils.PathInstructionJni(
JSONUtils.PathInstructionType.NAMED, fieldString, -1)
Array(key, named)
case _ => throw new UnsupportedOperationException(s"JSON field must be a scalar value")
}
}

withResource(fieldScalars) { fieldScalars =>
withResource(fieldScalars.safeMap(field => json.getJSONObject(field,
GetJsonObjectOptions.builder().allowSingleQuotes(true).build()))) { resultCols =>
val generatorCols = resultCols.safeMap(_.incRefCount).zip(schema).safeMap {
case (col, dataType) => GpuColumnVector.from(col, dataType)
}
val nonGeneratorCols = (0 until generatorOffset).safeMap { i =>
inputBatch.column(i).asInstanceOf[GpuColumnVector].incRefCount
}
new ColumnarBatch((nonGeneratorCols ++ generatorCols).toArray, inputBatch.numRows)
withResource(fieldInstructions.safeMap(field => JSONUtils.getJsonObject(json, 2, field))) {
resultCols =>
val generatorCols = resultCols.safeMap(_.incRefCount).zip(schema).safeMap {
case (col, dataType) => GpuColumnVector.from(col, dataType)
}
val nonGeneratorCols = (0 until generatorOffset).safeMap { i =>
inputBatch.column(i).asInstanceOf[GpuColumnVector].incRefCount
}
new ColumnarBatch((nonGeneratorCols ++ generatorCols).toArray, inputBatch.numRows)
}
}
}
Expand Down
Loading