diff --git a/integration_tests/src/main/python/udf_cudf_test.py b/integration_tests/src/main/python/udf_cudf_test.py
index c97ba94df1c..2e6093ea734 100644
--- a/integration_tests/src/main/python/udf_cudf_test.py
+++ b/integration_tests/src/main/python/udf_cudf_test.py
@@ -36,10 +36,13 @@
         'spark.rapids.sql.python.gpu.enabled': 'true'
         }
 
+small_data = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)]
 
-def _create_df(spark):
-    elements = list(map(lambda i: (i, i/1.0), range(1, 5000)))
-    return spark.createDataFrame(elements * 2, ("id", "v"))
+large_data = list(map(lambda i: (i, i/1.0), range(1, 5000))) * 2
+
+
+def _create_df(spark, data=large_data):
+    return spark.createDataFrame(data, ("id", "v"))
 
 
 # since this test requires to run different functions on CPU and GPU(need cudf),
@@ -76,13 +79,14 @@ def _plus_one_gpu_func(v: pd.Series) -> pd.Series:
 
 
 @cudf_udf
-def test_with_column(enable_cudf_udf):
+@pytest.mark.parametrize('data', [small_data, large_data], ids=['small data', 'large data'])
+def test_with_column(enable_cudf_udf, data):
     def cpu_run(spark):
-        df = _create_df(spark)
+        df = _create_df(spark, data)
         return df.withColumn("v1", _plus_one_cpu_func(df.v)).collect()
 
     def gpu_run(spark):
-        df = _create_df(spark)
+        df = _create_df(spark, data)
         return df.withColumn("v1", _plus_one_gpu_func(df.v)).collect()
 
     _assert_cpu_gpu(cpu_run, gpu_run, gpu_conf=_conf)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala
index 74dd87160b9..4995d8491aa 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala
@@ -187,6 +187,14 @@ class BatchQueue extends AutoCloseable with Arm {
     }
   }
 
+  def finish(): Unit = synchronized {
+    if (!isSet) {
+      // Wake up anyone waiting for the first batch.
+      isSet = true
+      notifyAll()
+    }
+  }
+
   def remove(): ColumnarBatch = synchronized {
     if (queue.isEmpty) {
       null
@@ -369,7 +377,8 @@ class GpuArrowPythonRunner(
     schema: StructType,
     timeZoneId: String,
     conf: Map[String, String],
-    batchSize: Long)
+    batchSize: Long,
+    onDataWriteFinished: () => Unit)
   extends BasePythonRunner[ColumnarBatch, ColumnarBatch](funcs, evalType, argOffsets)
     with GpuPythonArrowOutput {
 
@@ -431,6 +440,7 @@ class GpuArrowPythonRunner(
       } {
         writer.close()
         dataOut.flush()
+        if (onDataWriteFinished != null) onDataWriteFinished()
      }
     }
   }
@@ -587,7 +597,8 @@ case class GpuArrowEvalPythonExec(
         schema,
         sessionLocalTimeZone,
         pythonRunnerConf,
-        batchSize){
+        batchSize,
+        () => queue.finish()){
       override def minReadTargetBatchSize: Int = targetReadBatchSize
     }.compute(projectedIterator, context.partitionId(),
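
For context: the new `finish()` hook appears to guard against a hang on very small inputs, where the Python writer can consume all data and close before any batch reaches the queue, leaving a reader blocked waiting for a first batch that never arrives. Below is a minimal sketch of that producer/consumer wake-up pattern; `SimpleBatchQueue`, `add`, and `waitForFirst` are illustrative stand-ins, not the plugin's actual `BatchQueue` API.

```scala
import scala.collection.mutable

// Simplified stand-in for the plugin's BatchQueue (illustrative only).
class SimpleBatchQueue[T] {
  private val queue = mutable.Queue[T]()
  private var isSet = false // true once a batch has arrived or the writer finished

  def add(elem: T): Unit = synchronized {
    queue.enqueue(elem)
    isSet = true
    notifyAll() // wake a reader blocked in waitForFirst()
  }

  // Mirrors BatchQueue.finish(): called when the writer completes without
  // queueing anything, so a waiting reader is released instead of hanging.
  def finish(): Unit = synchronized {
    if (!isSet) {
      isSet = true
      notifyAll()
    }
  }

  def waitForFirst(): Unit = synchronized {
    while (!isSet) wait() // without finish(), this loops forever on empty input
  }

  def remove(): Option[T] = synchronized {
    if (queue.isEmpty) None else Some(queue.dequeue())
  }
}
```

This matches how the diff wires things up: the writer side invokes `onDataWriteFinished` after `writer.close()` and `dataOut.flush()`, and `GpuArrowEvalPythonExec` passes `() => queue.finish()` so the read side is always released.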