[SPARK-12717][PYTHON] Adding thread-safe broadcast pickle registry
## What changes were proposed in this pull request?

When using PySpark broadcast variables in a multi-threaded environment, `SparkContext._pickled_broadcast_vars` becomes a shared resource. A race condition can occur when broadcast variables pickled by one thread are added to the shared `_pickled_broadcast_vars` and end up in the Python command built by another thread. This PR introduces a thread-safe pickle registry backed by thread-local storage, so that when a Python command is pickled (causing the broadcast variables it uses to be pickled and added to the registry), each thread has its own view of the registry from which to retrieve and clear the broadcast variables it used.
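For background, the fix leans on the standard-library `threading.local`: when a subclass defines `__init__`, it runs once per thread on first access, so each thread gets its own backing set. A minimal sketch of that behavior (illustrative names only, not code from this patch):

```python
import threading


class LocalRegistry(threading.local):
    """Illustrative thread-local registry; __init__ runs once per thread."""

    def __init__(self):
        self.__dict__.setdefault("_registry", set())


registry = LocalRegistry()


def worker():
    # Additions made here are visible only to this worker thread.
    registry._registry.add("pickled-in-worker")
    assert registry._registry == {"pickled-in-worker"}


registry._registry.add("pickled-in-main")
t = threading.Thread(target=worker)
t.start()
t.join()

# The worker's addition never leaks into the main thread's view,
# which is exactly the isolation a plain shared set() was missing.
assert registry._registry == {"pickled-in-main"}
```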

## How was this patch tested?

Added a unit test that causes this race condition using another thread.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #18695 from BryanCutler/pyspark-bcast-threadsafe-SPARK-12717.
BryanCutler authored and HyukjinKwon committed Aug 1, 2017
1 parent 58da1a2 commit 77cc0d6
Showing 3 changed files with 65 additions and 2 deletions.
19 changes: 19 additions & 0 deletions python/pyspark/broadcast.py
```diff
@@ -19,6 +19,7 @@
 import sys
 import gc
 from tempfile import NamedTemporaryFile
+import threading
 
 from pyspark.cloudpickle import print_exec
 from pyspark.util import _exception_message
@@ -139,6 +140,24 @@ def __reduce__(self):
         return _from_id, (self._jbroadcast.id(),)
 
 
+class BroadcastPickleRegistry(threading.local):
+    """ Thread-local registry for broadcast variables that have been pickled
+    """
+
+    def __init__(self):
+        self.__dict__.setdefault("_registry", set())
+
+    def __iter__(self):
+        for bcast in self._registry:
+            yield bcast
+
+    def add(self, bcast):
+        self._registry.add(bcast)
+
+    def clear(self):
+        self._registry.clear()
+
+
 if __name__ == "__main__":
     import doctest
     (failure_count, test_count) = doctest.testmod()
```
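The new class keeps the same `add` / iterate / `clear` surface that callers used on the plain set, but each thread now sees only its own entries. A rough usage sketch, assuming a patched PySpark is importable (plain strings stand in for real `Broadcast` objects):

```python
import threading

from pyspark.broadcast import BroadcastPickleRegistry  # added by this commit

registry = BroadcastPickleRegistry()


def worker():
    registry.add("b2")                   # visible only inside this thread
    assert list(registry) == ["b2"]
    registry.clear()                     # clears only this thread's view


registry.add("b1")
t = threading.Thread(target=worker)
t.start()
t.join()

# The worker's add() and clear() never touched the main thread's entries.
assert list(registry) == ["b1"]
```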
4 changes: 2 additions & 2 deletions python/pyspark/context.py
```diff
@@ -30,7 +30,7 @@
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
-from pyspark.broadcast import Broadcast
+from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
@@ -195,7 +195,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
         # This allows other code to determine which Broadcast instances have
         # been pickled, so it can determine which Java broadcast objects to
         # send.
-        self._pickled_broadcast_vars = set()
+        self._pickled_broadcast_vars = BroadcastPickleRegistry()
 
         SparkFiles._sc = self
         root_dir = SparkFiles.getRootDirectory()
```
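The only change needed in `context.py` is the constructor swap; code that collects pickled broadcasts after serializing a command keeps working unchanged, it simply operates on the calling thread's view now. A simplified sketch of that consumer pattern (hypothetical helper name; the same steps appear in the new test's `process_vars`):

```python
def collect_pickled_broadcasts(sc):
    """Snapshot and clear the broadcasts pickled by the current thread."""
    # Pickling a Python command pickles each Broadcast it closes over, which
    # registers that variable in sc._pickled_broadcast_vars.
    pickled = list(sc._pickled_broadcast_vars)   # iterate this thread's registry
    sc._pickled_broadcast_vars.clear()           # other threads are unaffected
    return pickled
```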
44 changes: 44 additions & 0 deletions python/pyspark/tests.py
```diff
@@ -858,6 +858,50 @@ def test_multiple_broadcasts(self):
         self.assertEqual(N, size)
         self.assertEqual(checksum, csum)
 
+    def test_multithread_broadcast_pickle(self):
+        import threading
+
+        b1 = self.sc.broadcast(list(range(3)))
+        b2 = self.sc.broadcast(list(range(3)))
+
+        def f1():
+            return b1.value
+
+        def f2():
+            return b2.value
+
+        funcs_num_pickled = {f1: None, f2: None}
+
+        def do_pickle(f, sc):
+            command = (f, None, sc.serializer, sc.serializer)
+            ser = CloudPickleSerializer()
+            ser.dumps(command)
+
+        def process_vars(sc):
+            broadcast_vars = list(sc._pickled_broadcast_vars)
+            num_pickled = len(broadcast_vars)
+            sc._pickled_broadcast_vars.clear()
+            return num_pickled
+
+        def run(f, sc):
+            do_pickle(f, sc)
+            funcs_num_pickled[f] = process_vars(sc)
+
+        # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage
+        do_pickle(f1, self.sc)
+
+        # run all for f2, should only add/count/clear b2 from worker thread local storage
+        t = threading.Thread(target=run, args=(f2, self.sc))
+        t.start()
+        t.join()
+
+        # count number of vars pickled in main thread, only b1 should be counted and cleared
+        funcs_num_pickled[f1] = process_vars(self.sc)
+
+        self.assertEqual(funcs_num_pickled[f1], 1)
+        self.assertEqual(funcs_num_pickled[f2], 1)
+        self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0)
+
     def test_large_closure(self):
         N = 200000
         data = [float(i) for i in xrange(N)]
```
