[SPARK-49074][SQL] Fix variant with df.cache()
### What changes were proposed in this pull request?

Currently, the `actualSize` method of the `VARIANT` column type isn't overridden, so the 2KB default size is used as the `actualSize`. We should define `actualSize` so that the cached variant column can be correctly written to the byte buffer.

Currently, if the average per-variant size is greater than 2KB and the total column size is greater than 128KB (the default initial buffer size), an exception is (incorrectly) thrown: the buffer is grown based on the underestimated sizes, so writing the actual variant bytes eventually overflows it.
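
For context, a simplified sketch of why the per-value size matters when a column is cached. This is illustrative only, not the actual Spark source: the real logic lives in `ColumnBuilder`/`ColumnType`, and the helper below only mimics its buffer-growth idea.

```scala
// Illustrative sketch (not Spark's implementation): a builder reserves
// `actualSize` bytes before writing a value into the backing byte buffer.
import java.nio.ByteBuffer

def ensureFreeSpace(orig: ByteBuffer, size: Int): ByteBuffer = {
  if (orig.remaining >= size) {
    orig // enough room for the next value
  } else {
    // Grow the buffer so the reserved `size` bytes fit.
    val grown = ByteBuffer.allocate(math.max(orig.capacity * 2, orig.position() + size))
    orig.flip()
    grown.put(orig)
  }
}
```

If `actualSize` reports the 2KB default but the append then writes a multi-megabyte variant, the buffer is never grown enough for the write, hence the exception above. Overriding `actualSize` makes the reservation match the bytes actually written.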

### Why are the changes needed?

To fix caching of larger variants (via `df.cache()`), such as the ones exercised in the added unit tests.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47559 from richardc-db/fix_variant_cache.

Authored-by: Richard Chen <r.chen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
richardc-db authored and cloud-fan committed Aug 1, 2024
1 parent 06ed91a commit bf3ad7e
Showing 2 changed files with 59 additions and 0 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
@@ -829,6 +829,12 @@ private[columnar] object VARIANT
  /** Chosen to match the default size set in `VariantType`. */
  override def defaultSize: Int = 2048

  override def actualSize(row: InternalRow, ordinal: Int): Int = {
    val v = getField(row, ordinal)
    // 4 bytes each for the integers representing the 'value' and 'metadata' lengths.
    8 + v.getValue().length + v.getMetadata().length
  }

  override def getField(row: InternalRow, ordinal: Int): VariantVal = row.getVariant(ordinal)

  override def setField(row: InternalRow, ordinal: Int, value: VariantVal): Unit =
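
As a rough sanity check on the sizing used by the new tests below (a back-of-the-envelope estimate, not a measured value):

```scala
// Each test builds one JSON object with 128 * 1024 keys, where every entry
// looks like `"12345": "test", ` -- roughly 15 bytes on average.
val numKeys = 128 * 1024
val approxJsonBytes = numKeys * 15L
println(approxJsonBytes) // ~1.9 MB per variant: far above both the old 2KB
                         // default `actualSize` and the 128KB initial buffer
```
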
53 changes: 53 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -652,6 +652,21 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
    checkAnswer(df, expected.collect())
  }

test("variant with many keys in a cached row-based df") {
// The initial size of the buffer backing a cached dataframe column is 128KB.
// See `ColumnBuilder`.
val numKeys = 128 * 1024
var keyIterator = (0 until numKeys).iterator
val entries = Array.fill(numKeys)(s"""\"${keyIterator.next()}\": \"test\"""")
val jsonStr = s"{${entries.mkString(", ")}}"
val query = s"""select parse_json('${jsonStr}') v from range(0, 10)"""
val df = spark.sql(query)
df.cache()

val expected = spark.sql(query)
checkAnswer(df, expected.collect())
}

test("struct of variant in a cached row-based df") {
val query = """select named_struct(
'v', parse_json(format_string('{\"a\": %s}', id)),
@@ -680,6 +695,21 @@
    checkAnswer(df, expected.collect())
  }

test("array variant with many keys in a cached row-based df") {
// The initial size of the buffer backing a cached dataframe column is 128KB.
// See `ColumnBuilder`.
val numKeys = 128 * 1024
var keyIterator = (0 until numKeys).iterator
val entries = Array.fill(numKeys)(s"""\"${keyIterator.next()}\": \"test\"""")
val jsonStr = s"{${entries.mkString(", ")}}"
val query = s"""select array(parse_json('${jsonStr}')) v from range(0, 10)"""
val df = spark.sql(query)
df.cache()

val expected = spark.sql(query)
checkAnswer(df, expected.collect())
}

test("map of variant in a cached row-based df") {
val query = """select map(
'v', parse_json(format_string('{\"a\": %s}', id)),
@@ -711,6 +741,29 @@
    }
  }

test("variant with many keys in a cached column-based df") {
withTable("t") {
// The initial size of the buffer backing a cached dataframe column is 128KB.
// See `ColumnBuilder`.
val numKeys = 128 * 1024
var keyIterator = (0 until numKeys).iterator
val entries = Array.fill(numKeys)(s"""\"${keyIterator.next()}\": \"test\"""")
val jsonStr = s"{${entries.mkString(", ")}}"
val query = s"""select named_struct(
'v', parse_json('$jsonStr'),
'null_v', cast(null as variant),
'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
) v
from range(0, 10)"""
spark.sql(query).write.format("parquet").mode("overwrite").saveAsTable("t")
val df = spark.sql("select * from t")
df.cache()

val expected = spark.sql(query)
checkAnswer(df, expected.collect())
}
}

test("variant_get size") {
val largeKey = "x" * 1000
val df = Seq(s"""{ "$largeKey": {"a" : 1 },