
Put the GPU data back on host before processing cache on CPU (#2970)
* make sure the gpu batch gets copied to host before accessing it

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review changes

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

Co-authored-by: Raza Jafri <rjafri@nvidia.com>
razajafri authored Jul 22, 2021
1 parent 283d87c commit 8cdf2d9
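
The fix in a nutshell: the cached-batch writer was handing a GPU-resident ColumnarBatch straight to row-by-row CPU processing. Below is a minimal sketch of the copy-to-host pattern this commit applies, written as a hypothetical standalone helper. GpuColumnVector and copyToHost are the plugin APIs used in the diff; the name ensureOnHost and the try/finally cleanup are illustrative assumptions (the real code uses the plugin's withResource/safeMap helpers, which also clean up correctly on partial failure).

import com.nvidia.spark.rapids.GpuColumnVector
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}

// Hypothetical helper mirroring the commit: before a batch is read
// row-by-row on the CPU, make sure every column lives in host memory.
def ensureOnHost(batch: ColumnarBatch): ColumnarBatch = {
  if (batch.column(0).isInstanceOf[GpuColumnVector]) {
    try {
      // Copy each device column to a host-side ColumnVector.
      val hostCols: Array[ColumnVector] = (0 until batch.numCols()).map { i =>
        batch.column(i).asInstanceOf[GpuColumnVector].copyToHost(): ColumnVector
      }.toArray
      new ColumnarBatch(hostCols, batch.numRows())
    } finally {
      batch.close() // release device memory once the host copies exist
    }
  } else {
    batch // already host-resident; nothing to copy
  }
}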
Showing 2 changed files with 35 additions and 1 deletion.
14 changes: 14 additions & 0 deletions integration_tests/src/main/python/cache_test.py
@@ -280,6 +280,20 @@ def helper(spark):
     # NOTE: we aren't comparing cpu and gpu results, we are comparing the cached and non-cached results.
     assert_equal(reg_result, cached_result)
 
+
+@pytest.mark.parametrize('enable_vectorized', enable_vectorized_confs, ids=idfn)
+def test_cache_array(enable_vectorized):
+    def helper(spark):
+        data = [("aaa", "123 456 789"), ("bbb", "444 555 666"), ("ccc", "777 888 999")]
+        columns = ["a","b"]
+        df = spark.createDataFrame(data).toDF(*columns)
+        newdf = df.withColumn('newb', f.split(f.col('b'),' '))
+        newdf.persist()
+        return newdf.count()
+
+    with_gpu_session(helper, conf = enable_vectorized)
+
+
 def function_to_test_on_cached_df(with_x_session, func, data_gen, test_conf):
     def with_cache(cached):
         def helper(spark):
22 changes: 21 additions & 1 deletion ParquetCachedBatchSerializer.scala
@@ -1289,7 +1289,27 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
    */
   class ColumnarBatchToCachedBatchIterator extends InternalRowToCachedBatchIterator {
     override def getIterator: Iterator[InternalRow] = {
-      iter.asInstanceOf[Iterator[ColumnarBatch]].next.rowIterator().asScala
+
+      new Iterator[InternalRow] {
+        // We have to check for null context because of the unit test
+        Option(TaskContext.get).foreach(_.addTaskCompletionListener[Unit](_ => hostBatch.close()))
+
+        val batch: ColumnarBatch = iter.asInstanceOf[Iterator[ColumnarBatch]].next
+        val hostBatch = if (batch.column(0).isInstanceOf[GpuColumnVector]) {
+          withResource(batch) { batch =>
+            new ColumnarBatch(batch.safeMap(_.copyToHost()).toArray, batch.numRows())
+          }
+        } else {
+          batch
+        }
+
+        val rowIterator = hostBatch.rowIterator().asScala
+
+        override def next: InternalRow = rowIterator.next
+
+        override def hasNext: Boolean = rowIterator.hasNext
+
+      }
     }
   }
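
For reference, this iterator only runs when Spark is told to use the plugin's serializer for DataFrame caching. A sketch of the session setup follows, under the assumption that the config values match the RAPIDS Accelerator docs for this era; spark.sql.cache.serializer is a static conf in Spark 3.1+, so it must be set before the session is created.

import org.apache.spark.sql.SparkSession

// Assumed class names per the RAPIDS Accelerator docs; adjust to the
// version you run. Static confs cannot be changed on a live session.
val spark = SparkSession.builder()
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .config("spark.sql.cache.serializer", "com.nvidia.spark.ParquetCachedBatchSerializer")
  .getOrCreate()

// df.persist() / df.cache() now round-trips cached data through this
// serializer, copying GPU batches to host as in the iterator above.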

