diff --git a/integration_tests/src/main/python/row_conversion_test.py b/integration_tests/src/main/python/row_conversion_test.py index 92ea05d68be..02617965fb9 100644 --- a/integration_tests/src/main/python/row_conversion_test.py +++ b/integration_tests/src/main/python/row_conversion_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,7 +28,12 @@ # to be brought back to the CPU (rows) to be returned. # So we just need a very simple operation in the middle that # can be done on the GPU. -def test_row_conversions(): +@pytest.mark.parametrize('override_batch_size_bytes', [None, '4mb', '1kb'], ids=idfn) +def test_row_conversions(override_batch_size_bytes): + conf = {} + if override_batch_size_bytes is not None: + conf["spark.rapids.sql.batchSizeBytes"] = override_batch_size_bytes + gens = [["a", byte_gen], ["b", short_gen], ["c", int_gen], ["d", long_gen], ["e", float_gen], ["f", double_gen], ["g", string_gen], ["h", boolean_gen], ["i", timestamp_gen], ["j", date_gen], ["k", ArrayGen(byte_gen)], @@ -40,7 +45,7 @@ def test_row_conversions(): ["s", null_gen], ["t", decimal_gen_64bit], ["u", decimal_gen_32bit], ["v", decimal_gen_128bit]] assert_gpu_and_cpu_are_equal_collect( - lambda spark : gen_df(spark, gens).selectExpr("*", "a as a_again")) + lambda spark : gen_df(spark, gens).selectExpr("*", "a as a_again"), conf=conf) def test_row_conversions_fixed_width(): gens = [["a", byte_gen], ["b", short_gen], ["c", int_gen], ["d", long_gen], diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java index 1ddc49ffbf1..31864859207 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java +++ 
b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java @@ -17,6 +17,7 @@ package com.nvidia.spark.rapids; import ai.rapids.cudf.*; +import com.nvidia.spark.Retryable; import com.nvidia.spark.rapids.shims.GpuTypeShims; import org.apache.arrow.memory.ReferenceManager; @@ -234,7 +235,8 @@ public void close() { } } - public static final class GpuColumnarBatchBuilder extends GpuColumnarBatchBuilderBase { + public static final class GpuColumnarBatchBuilder extends GpuColumnarBatchBuilderBase + implements Retryable { private final RapidsHostColumnBuilder[] builders; private ai.rapids.cudf.HostColumnVector[] hostColumns; @@ -266,6 +268,45 @@ public GpuColumnarBatchBuilder(StructType schema, int rows) { } } + /** + * A collection of builders for building up columnar data. + * @param schema the schema of the batch. + * @param rows the maximum number of rows in this batch. + * @param spillableHostBuf single spillable host buffer to slice up among columns + * @param bufferSizes an array of sizes for each column + */ + public GpuColumnarBatchBuilder(StructType schema, int rows, + SpillableHostBuffer spillableHostBuf, long[] bufferSizes) { + fields = schema.fields(); + int len = fields.length; + builders = new RapidsHostColumnBuilder[len]; + boolean success = false; + try (SpillableHostBuffer sBuf = spillableHostBuf; + HostMemoryBuffer hBuf = + RmmRapidsRetryIterator.withRetryNoSplit(() -> sBuf.getHostBuffer());) { + long offset = 0; + for (int i = 0; i < len; i++) { + StructField field = fields[i]; + try (HostMemoryBuffer columnBuffer = hBuf.slice(offset, bufferSizes[i]);) { + offset += bufferSizes[i]; + builders[i] = + new RapidsHostColumnBuilder(convertFrom(field.dataType(), field.nullable()), rows) + .preAllocateBuffers(columnBuffer, 0); + } + } + success = true; + } finally { + if (!success) { + for (RapidsHostColumnBuilder b: builders) { + if (b != null) { + b.close(); + } + } + } + } + } + + @Override public void copyColumnar(ColumnVector cv, int colNum, 
int rows) { if (builders.length > 0) { @@ -337,6 +378,32 @@ public void close() { } } } + + @Override + public void checkpoint() { + for (RapidsHostColumnBuilder b: builders) { + if (b != null) { + b.checkpoint(); + } + } + } + + @Override + public void restore() { + for (RapidsHostColumnBuilder b: builders) { + if (b != null) { + b.restore(); + } + } + } + + public void setAllowGrowth(boolean enable) { + for (RapidsHostColumnBuilder b: builders) { + if (b != null) { + b.setAllowGrowth(enable); + } + } + } } private static final class ArrowBufReferenceHolder { diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnBuilder.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnBuilder.java index d9d8411643b..729f0f10d67 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnBuilder.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnBuilder.java @@ -16,6 +16,8 @@ package com.nvidia.spark.rapids; +import com.nvidia.spark.Retryable; + import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.DType; import ai.rapids.cudf.HostColumnVector; @@ -35,17 +37,21 @@ * This is a copy of the cudf HostColumnVector.ColumnBuilder class. * Moving this here to allow for iterating on host memory oom handling. 
*/ -public final class RapidsHostColumnBuilder implements AutoCloseable { +public final class RapidsHostColumnBuilder implements AutoCloseable, Retryable { + private boolean allowGrowth = true; private HostColumnVector.DataType dataType; private DType type; + private long currentInitBufferOffset = 0l; private HostMemoryBuffer data; private HostMemoryBuffer valid; private HostMemoryBuffer offsets; private long nullCount = 0l; - //TODO nullable currently not used + private long checkpointNullCount = 0; private boolean nullable; private long rows; + + private long checkpointRows; private long estimatedRows; private long rowCapacity = 0L; private long validCapacity = 0L; @@ -56,8 +62,12 @@ public final class RapidsHostColumnBuilder implements AutoCloseable { // The value of currentIndex can't exceed Int32.Max. Storing currentIndex as a long is to // adapt HostMemoryBuffer.setXXX, which requires a long offset. private long currentIndex = 0; + private long checkpointCurrentIndex = 0; + // Only for Strings: pointer of the byte (data) buffer private int currentStringByteIndex = 0; + private int checkpointCurrentStringByteIndex = 0; + // Use bit shift instead of multiply to transform row offset to byte offset private int bitShiftBySize = 0; /** @@ -67,6 +77,7 @@ public final class RapidsHostColumnBuilder implements AutoCloseable { private static final int bitShiftByOffset = (int) (Math.log(OFFSET_SIZE) / Math.log(2)); public RapidsHostColumnBuilder(HostColumnVector.DataType dataType, long estimatedRows) { + this.allowGrowth = true; this.dataType = dataType; this.type = dataType.getType(); this.nullable = dataType.isNullable(); @@ -82,6 +93,98 @@ public RapidsHostColumnBuilder(HostColumnVector.DataType dataType, long estimate } } + @Override + public void checkpoint() { + checkpointRows = rows; + checkpointCurrentIndex = currentIndex; + checkpointCurrentStringByteIndex = currentStringByteIndex; + checkpointNullCount = nullCount; + for (RapidsHostColumnBuilder child : 
childBuilders) { + child.checkpoint(); + } + } + + @Override + public void restore() { + // May need to reset the validity bits + if (nullable && (valid != null) && (currentIndex > checkpointCurrentIndex)) { + for (long i = checkpointCurrentIndex; i < currentIndex; i++) { + resetNullAt(valid, i); + } + } + currentIndex = checkpointCurrentIndex; + currentStringByteIndex = checkpointCurrentStringByteIndex; + nullCount = checkpointNullCount; + rows = checkpointRows; + for (RapidsHostColumnBuilder child : childBuilders) { + child.restore(); + } + } + + private long getInitBufferOffset() { + return this.currentInitBufferOffset; + } + + private void preAllocateOffsets(HostMemoryBuffer initBuffer) { + long neededSize = (estimatedRows + 1) << bitShiftByOffset; + offsets = initBuffer.slice(this.currentInitBufferOffset, neededSize); + offsets.setInt(0, 0); + this.currentInitBufferOffset += neededSize; + } + + private void preAllocateData(HostMemoryBuffer initBuffer, long neededSize) { + data = initBuffer.slice(this.currentInitBufferOffset, neededSize); + this.currentInitBufferOffset += neededSize; + } + + private void preAllocateValidity(HostMemoryBuffer initBuffer) { + // This is the same as ColumnView.getValidityBufferSize + // number of bytes required = Math.ceil(number of bits / 8) + long actualBytes = ((estimatedRows) + 7) >> 3; + // padding to the adding boundary(64 bytes) + long maskBytes = ((actualBytes + 63) >> 6) << 6; + valid = initBuffer.slice(this.currentInitBufferOffset, maskBytes); + this.currentInitBufferOffset += maskBytes; + valid.setMemory(0, valid.getLength(), (byte) 0xFF); + validCapacity = estimatedRows; + } + + public RapidsHostColumnBuilder preAllocateBuffers(HostMemoryBuffer initBuffer, long offset) { + this.allowGrowth = false; + this.currentInitBufferOffset = offset; + + if (this.type == DType.LIST) { + preAllocateOffsets(initBuffer); + } else if (this.type == DType.STRING) { + // Initialize data buffer with 20 bytes per string to match spark 
default. + preAllocateData(initBuffer, estimatedRows * 20); + preAllocateOffsets(initBuffer); + } else if (this.type == DType.STRUCT) { + // just set rowCapacity below + } else { + preAllocateData(initBuffer, estimatedRows << bitShiftBySize); + } + rowCapacity = estimatedRows; + + // Pre-allocate validity buffer if needed + if (this.nullable) { + preAllocateValidity(initBuffer); + } + + for (int i = 0; i < dataType.getNumChildren(); i++) { + childBuilders.get(i).preAllocateBuffers(initBuffer, this.currentInitBufferOffset); + this.currentInitBufferOffset = childBuilders.get(i).getInitBufferOffset(); + } + return this; + } + + public void setAllowGrowth(boolean enable) { + this.allowGrowth = enable; + for (RapidsHostColumnBuilder child : childBuilders) { + child.setAllowGrowth(enable); + } + } + private void setupNullHandler() { if (this.type == DType.LIST) { this.nullHandler = () -> { @@ -120,9 +223,15 @@ public HostColumnVector build() { for (RapidsHostColumnBuilder childBuilder : childBuilders) { hostColumnVectorCoreList.add(childBuilder.buildNestedInternal()); } - // Aligns the valid buffer size with other buffers in terms of row size, because it grows lazily. if (valid != null) { - growValidBuffer(); + // The valid buffer might have been pre-allocated, but never used. If so, close it. + if (nullCount == 0) { + valid.close(); + valid = null; + } else { + // Aligns the valid buffer size with other buffers in terms of row size, because it grows lazily. + growValidBuffer(); + } } HostColumnVector hostColumnVector = new HostColumnVector(type, rows, Optional.of(nullCount), data, valid, offsets, hostColumnVectorCoreList); @@ -135,9 +244,15 @@ private HostColumnVectorCore buildNestedInternal() { for (RapidsHostColumnBuilder childBuilder : childBuilders) { hostColumnVectorCoreList.add(childBuilder.buildNestedInternal()); } - // Aligns the valid buffer size with other buffers in terms of row size, because it grows lazily. 
if (valid != null) { - growValidBuffer(); + // The valid buffer might have been pre-allocated, but never used. If so, close it. + if (nullCount == 0) { + valid.close(); + valid = null; + } else { + // Aligns the valid buffer size with other buffers in terms of row size, because it grows lazily. + growValidBuffer(); + } } return new HostColumnVectorCore(type, rows, Optional.of(nullCount), data, valid, offsets, hostColumnVectorCoreList); @@ -186,6 +301,11 @@ private void growValidBuffer() { return; } if (validCapacity < rowCapacity) { + if (!this.allowGrowth) { + throw new RapidsHostColumnOverflow ( + "attempt to add rows beyond preallocated capacity: " + rowCapacity); + } + // This is the same as ColumnView.getValidityBufferSize // number of bytes required = Math.ceil(number of bits / 8) long actualBytes = ((rowCapacity) + 7) >> 3; @@ -220,6 +340,10 @@ private void growFixedWidthBuffersAndRows(int numRows) { data = HostMemoryBuffer.allocate(neededSize << bitShiftBySize); rowCapacity = neededSize; } else if (rows > rowCapacity) { + if (!this.allowGrowth) { + throw new RapidsHostColumnOverflow ( + "attempt to add rows beyond preallocated capacity: " + rowCapacity); + } long neededSize = Math.max(rows, rowCapacity * 2); long newCap = Math.min(neededSize, Integer.MAX_VALUE - 1); data = copyBuffer(HostMemoryBuffer.allocate(newCap << bitShiftBySize), data); @@ -240,6 +364,10 @@ private void growListBuffersAndRows() { offsets.setInt(0, 0); rowCapacity = estimatedRows; } else if (rows > rowCapacity) { + if (!this.allowGrowth) { + throw new RapidsHostColumnOverflow ( + "attempt to add rows beyond preallocated capacity: " + rowCapacity); + } long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 2); offsets = copyBuffer(HostMemoryBuffer.allocate((newCap + 1) << bitShiftByOffset), offsets); rowCapacity = newCap; @@ -266,6 +394,10 @@ private void growStringBuffersAndRows(int stringLength) { } if (rows > rowCapacity) { + if (!this.allowGrowth) { + throw new 
RapidsHostColumnOverflow ( + "attempt to add rows beyond preallocated capacity: " + rowCapacity); + } long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 2); offsets = copyBuffer(HostMemoryBuffer.allocate((newCap + 1) << bitShiftByOffset), offsets); rowCapacity = newCap; @@ -273,6 +405,11 @@ private void growStringBuffersAndRows(int stringLength) { long currentLength = currentStringByteIndex + stringLength; if (currentLength > data.getLength()) { + if (!this.allowGrowth) { + throw new RapidsHostColumnOverflow ( + "attempt to add string bytes beyond preallocated capacity: " + data.getLength()); + } + long requiredLength = data.getLength(); do { requiredLength = requiredLength * 2; @@ -293,6 +430,10 @@ private void growStructBuffersAndRows() { if (rowCapacity == 0) { rowCapacity = estimatedRows; } else if (rows > rowCapacity) { + if (!this.allowGrowth) { + throw new RapidsHostColumnOverflow ( + "attempt to add row beyond preallocated capacity: " + rowCapacity); + } rowCapacity = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 1); } } @@ -311,6 +452,20 @@ private HostMemoryBuffer copyBuffer(HostMemoryBuffer targetBuffer, HostMemoryBuf return buffer; } + /** + * Reset the validity bit for the given index (used by restore). + * + * @param valid the buffer to reset it in. + * @param index the index to reset it at. + */ + static void resetNullAt(HostMemoryBuffer valid, long index) { + long bucket = index / 8; + byte currentByte = valid.getByte(bucket); + int bitmask = (1 << (index % 8)) & 0x00ff; + currentByte |= bitmask; + valid.setByte(bucket, currentByte); + } + /** * Set the validity bit to null for the given index. 
* diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnOverflow.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnOverflow.java new file mode 100644 index 00000000000..a9efb6977b4 --- /dev/null +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnOverflow.java @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids; + +public class RapidsHostColumnOverflow extends RuntimeException { + public RapidsHostColumnOverflow() { + super(); + } + + public RapidsHostColumnOverflow(String message) { + super(message); + } +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala index 92588885be0..7474ddfa1a2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -67,6 +67,11 @@ object GpuBatchUtils { estimateGpuMemory(field.dataType, field.nullable, rowCount) } + /** Estimate the amount of GPU memory a batch of rows will occupy per column once converted */ + def estimatePerColumnGpuMemory(schema: StructType, rowCount: Long): Array[Long] = { + schema.fields.indices.map(estimateGpuMemory(schema, _, rowCount)).toArray + } + /** * Get the minimum size a column could be that matches these conditions. */ diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRowToColumnarExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRowToColumnarExec.scala index 99f17cf341a..038d3ba41cb 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRowToColumnarExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRowToColumnarExec.scala @@ -17,8 +17,10 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{NvtxColor, NvtxRange} +import com.nvidia.spark.rapids.Arm.closeOnExcept import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.GpuColumnVector.GpuColumnarBatchBuilder +import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitTargetSizeInHalfCpu, withRetry} import com.nvidia.spark.rapids.shims.{GpuTypeShims, ShimUnaryExecNode} import org.apache.spark.TaskContext @@ -600,16 +602,32 @@ class RowToColumnarIterator( private var targetRows = 0 private var totalOutputBytes: Long = 0 private var totalOutputRows: Long = 0 + private[this] val pending = new scala.collection.mutable.Queue[InternalRow]() - override def hasNext: Boolean = rowIter.hasNext + override def hasNext: Boolean = pending.nonEmpty || rowIter.hasNext override def next(): ColumnarBatch = { - if (!rowIter.hasNext) { + if (!hasNext) { throw new NoSuchElementException } buildBatch() } + // Attempt to allocate a single host buffer for the full batch of columns, retrying + // with fewer rows if necessary. Then make it spillable. + // Returns a tuple of (actual rows, per-column-sizes, SpillableHostBuffer).
+ private def allocBufWithRetry(rows : Int) : (Int, Array[Long], SpillableHostBuffer) = { + val targetRowCount = AutoCloseableTargetSize(rows, 1) + withRetry(targetRowCount, splitTargetSizeInHalfCpu) { attempt => + val perColBytes = GpuBatchUtils.estimatePerColumnGpuMemory(localSchema, attempt.targetSize) + closeOnExcept(HostAlloc.alloc(perColBytes.sum, true)) { hBuf => + (attempt.targetSize.toInt, perColBytes, + SpillableHostBuffer(hBuf, hBuf.getLength, SpillPriorities.ACTIVE_ON_DECK_PRIORITY, + RapidsBufferCatalog.singleton)) + } + }.next() + } + private def buildBatch(): ColumnarBatch = { withResource(new NvtxRange("RowToColumnar", NvtxColor.CYAN)) { _ => val streamStart = System.nanoTime() @@ -625,21 +643,50 @@ class RowToColumnarIterator( targetRows = GpuBatchUtils.estimateRowCount(targetSizeBytes, sampleBytes, sampleRows) } } + val (actualRows, perColumnBytes, sBuf) = allocBufWithRetry(targetRows) + targetRows = actualRows - withResource(new GpuColumnarBatchBuilder(localSchema, targetRows)) { builders => + withResource(new GpuColumnarBatchBuilder(localSchema, targetRows, sBuf, + perColumnBytes)) { builders => var rowCount = 0 // Double because validity can be < 1 byte, and this is just an estimate anyways var byteCount: Double = 0 + var overWrite = false // read at least one row - while (rowIter.hasNext && - (rowCount == 0 || rowCount < targetRows && byteCount < targetSizeBytes)) { - val row = rowIter.next() - byteCount += converters.convert(row, builders) - rowCount += 1 + while (!overWrite && hasNext && (rowCount == 0 || + ((rowCount < targetRows) && (byteCount < targetSizeBytes)))) { + val row = if (pending.nonEmpty) { + pending.dequeue() + } else { + rowIter.next() + } + try { + builders.checkpoint() + val rowBytes = converters.convert(row, builders) + byteCount += rowBytes + rowCount += 1 + } catch { + case _ : RapidsHostColumnOverflow => { + // We overwrote the pre-allocated buffers. Restore state and stop here if we can. 
+ builders.restore() + // If this happens on the first row, we aren't going to succeed. If we require + // a single batch, it will fail below. + // For now we will just retry these cases with growth re-enabled - we may run out + // of memory though. + if ((rowCount == 0) || (localGoal.isInstanceOf[RequireSingleBatchLike])) { + builders.setAllowGrowth(true) + } else { + // We wrote some rows, so we can go on to building the batch + overWrite = true + } + pending.enqueue(row) // we need to try this row again + } + case e: Throwable => throw e + } } // enforce RequireSingleBatch limit - if (rowIter.hasNext && localGoal.isInstanceOf[RequireSingleBatchLike]) { + if (hasNext && localGoal.isInstanceOf[RequireSingleBatchLike]) { throw new IllegalStateException("A single batch is required for this operation." + " Please try increasing your partition count.") } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RowToColumnarIteratorRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RowToColumnarIteratorRetrySuite.scala index 09d3ecf5881..b0eea625f3e 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RowToColumnarIteratorRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RowToColumnarIteratorRetrySuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._ class RowToColumnarIteratorRetrySuite extends RmmSparkRetrySuiteBase { private val schema = StructType(Seq(StructField("a", IntegerType))) - test("test simple OOM retry") { + test("test simple GPU OOM retry") { val rowIter: Iterator[InternalRow] = (1 to 10).map(InternalRow(_)).toIterator val row2ColIter = new RowToColumnarIterator( rowIter, schema, RequireSingleBatch, new GpuRowToColumnConverter(schema)) @@ -35,7 +35,7 @@ class RowToColumnarIteratorRetrySuite extends RmmSparkRetrySuiteBase { } } - test("test simple OOM split and retry") { + test("test simple GPU OOM split and retry") { val rowIter: Iterator[InternalRow] = (1 to 10).map(InternalRow(_)).toIterator val 
row2ColIter = new RowToColumnarIterator( rowIter, schema, RequireSingleBatch, new GpuRowToColumnConverter(schema)) @@ -45,4 +45,26 @@ class RowToColumnarIteratorRetrySuite extends RmmSparkRetrySuiteBase { row2ColIter.next() } } + + test("test simple CPU OOM retry") { + val rowIter: Iterator[InternalRow] = (1 to 10).map(InternalRow(_)).toIterator + val row2ColIter = new RowToColumnarIterator( + rowIter, schema, RequireSingleBatch, new GpuRowToColumnConverter(schema)) + RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.CPU.ordinal, 0) + Arm.withResource(row2ColIter.next()) { batch => + assertResult(10)(batch.numRows()) + } + } + + test("test simple CPU OOM split and retry") { + val rowIter: Iterator[InternalRow] = (1 to 10).map(InternalRow(_)).toIterator + val row2ColIter = new RowToColumnarIterator( + rowIter, schema, RequireSingleBatch, new GpuRowToColumnConverter(schema)) + RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.CPU.ordinal, 0) + Arm.withResource(row2ColIter.next()) { batch => + assertResult(10)(batch.numRows()) + } + } }