[SPARK-40932][CORE] Fix issue messages for allGather are overridden
### What changes were proposed in this pull request?

The messages returned by `allGather` may be overwritten by subsequent barrier API calls, e.g.:

``` scala
      val messages: Array[String] = context.allGather("ABC")
      context.barrier()
```

Here `messages` may come back as `Array("", "")`, although the expected result is `Array("ABC", "ABC")`.

The root cause is that, in local mode, the [messages returned by allGather](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala#L102) point to the same array as the [original messages](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala#L107) held by the coordinator. When a subsequent barrier API call modifies that array, the result already returned by `allGather` changes with it, so users get an incorrect result.

This PR fixes the issue by sending back a clone of the messages.
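
The aliasing described above can be reproduced without Spark. The following is only a minimal sketch: `AllGatherAliasingSketch`, `replyShared`, and `replyCloned` are hypothetical names, not part of Spark, and it assumes that a later barrier call writes empty strings into the same slots, which matches the `Array("", "")` symptom.

``` scala
// Hypothetical stand-in for the coordinator's shared message array; not Spark's actual class.
object AllGatherAliasingSketch {
  // Replies with the shared array itself (the behavior before the fix).
  def replyShared(state: Array[String]): Array[String] = state

  // Replies with a copy of the array (the behavior after the fix).
  def replyCloned(state: Array[String]): Array[String] = state.clone()

  def main(args: Array[String]): Unit = {
    // Messages collected by an allGather("ABC") across two tasks.
    val state = Array("ABC", "ABC")
    val shared = replyShared(state)
    val cloned = replyCloned(state)

    // Assumption: a following barrier() call overwrites each slot with "".
    state.indices.foreach(i => state(i) = "")

    // The shared reference sees the overwrite; the clone keeps the gathered values.
    assert(shared.sameElements(Array("", "")))
    assert(cloned.sameElements(Array("ABC", "ABC")))
  }
}
```

`clone()` on an array is a shallow copy. That appears to be sufficient here because the elements are immutable `String` values, so only the array slots themselves can be overwritten.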

### Why are the changes needed?

The bug described above may block external Spark ML libraries that depend heavily on the Spark barrier API for synchronization. If the barrier mechanism cannot guarantee the correctness of the barrier APIs, the consequences for those libraries would be severe.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

I added a unit test; it passes with this PR.

Closes apache#38410 from wbo4958/allgather-issue.

Authored-by: Bobby Wang <wbo4958@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
wbo4958 authored and cloud-fan committed Oct 28, 2022
1 parent 77694b4 commit 0b892a5
Showing 2 changed files with 24 additions and 1 deletion.
@@ -176,7 +176,7 @@ private[spark] class BarrierCoordinator(
         logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
           s"$taskId, current progress: ${requesters.size}/$numTasks.")
         if (requesters.size == numTasks) {
-          requesters.foreach(_.reply(messages))
+          requesters.foreach(_.reply(messages.clone()))
           // Finished current barrier() call successfully, clean up ContextBarrierState and
           // increase the barrier epoch.
           logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " +
@@ -367,4 +367,27 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with
     // double check we kill task success
     assert(System.currentTimeMillis() - startTime < 5000)
   }
+
+  test("SPARK-40932, messages of allGather should not been overridden " +
+    "by the following barrier APIs") {
+
+    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local[2]"))
+    sc.setLogLevel("INFO")
+    val rdd = sc.makeRDD(1 to 10, 2)
+    val rdd2 = rdd.barrier().mapPartitions { it =>
+      val context = BarrierTaskContext.get()
+      // Sleep for a random time before global sync.
+      Thread.sleep(Random.nextInt(1000))
+      // Pass partitionId message in
+      val message: String = context.partitionId().toString
+      val messages: Array[String] = context.allGather(message)
+      context.barrier()
+      Iterator.single(messages.toList)
+    }
+    val messages = rdd2.collect()
+    // All the task partitionIds are shared across all tasks
+    assert(messages.length === 2)
+    assert(messages.forall(_ == List("0", "1")))
+  }
+
 }
