Skip to content

Commit

Permalink
[SPARK-10542] [PYSPARK] fix serialize namedtuple
Browse files Browse the repository at this point in the history
Author: Davies Liu <davies@databricks.com>

Closes #8707 from davies/fix_namedtuple.
  • Loading branch information
Davies Liu authored and davies committed Sep 15, 2015
1 parent 1a09552 commit 5520418
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
15 changes: 14 additions & 1 deletion python/pyspark/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,11 @@ def save_global(self, obj, name=None, pack=struct.pack):
if new_override:
d['__new__'] = obj.__new__

# workaround for namedtuple (hijacked by PySpark)
if getattr(obj, '_is_namedtuple_', False):
self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields))
return

self.save(_load_class)
self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj)
d.pop('__doc__', None)
Expand Down Expand Up @@ -382,7 +387,7 @@ def save_instancemethod(self, obj):
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj)
else:
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__),
obj=obj)
obj=obj)
dispatch[types.MethodType] = save_instancemethod

def save_inst(self, obj):
Expand Down Expand Up @@ -744,6 +749,14 @@ def _load_class(cls, d):
return cls


def _load_namedtuple(name, fields):
"""
Loads a class generated by namedtuple
"""
from collections import namedtuple
return namedtuple(name, fields)


"""Constructors for 3rd party libraries
Note: These can never be renamed due to client compatibility issues"""

Expand Down
1 change: 1 addition & 0 deletions python/pyspark/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def _hack_namedtuple(cls):
def __reduce__(self):
return (_restore, (name, fields, tuple(self)))
cls.__reduce__ = __reduce__
cls._is_namedtuple_ = True
return cls


Expand Down
5 changes: 5 additions & 0 deletions python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ def test_namedtuple(self):
p2 = loads(dumps(p1, 2))
self.assertEqual(p1, p2)

from pyspark.cloudpickle import dumps
P2 = loads(dumps(P))
p3 = P2(1, 3)
self.assertEqual(p1, p3)

def test_itemgetter(self):
from operator import itemgetter
ser = CloudPickleSerializer()
Expand Down

0 comments on commit 5520418

Please sign in to comment.