cleanup, fix hash of string in 3.3+
Davies Liu committed Mar 24, 2015
1 parent 8662d5b commit 5707476
Showing 26 changed files with 156 additions and 161 deletions.
3 changes: 3 additions & 0 deletions bin/spark-submit
@@ -19,6 +19,9 @@

SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"

# disable randomized hash for string in Python 3.3+
export PYTHONHASHSEED=1

# Only define a usage function if an upstream script hasn't done so.
if ! type -t usage >/dev/null 2>&1; then
usage() {
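Why pinning PYTHONHASHSEED matters: from Python 3.3 on, hash() of a str is salted with a per-process random seed, so two workers can disagree on which partition a string key hashes to. A minimal, self-contained sketch of that failure mode (the key, seed values, and partition count are illustrative); with the seed exported above, every interpreter agrees:

import os
import subprocess
import sys

def partition_of(key, num_partitions, seed):
    # Ask a fresh interpreter, started with the given hash seed, where a
    # plain hash()-based partitioner would send this key.
    env = dict(os.environ, PYTHONHASHSEED=str(seed))
    code = "print(hash(%r) %% %d)" % (key, num_partitions)
    return int(subprocess.check_output([sys.executable, "-c", code], env=env))

print(partition_of("spark", 8, seed=1))    # the seed spark-submit exports
print(partition_of("spark", 8, seed=42))   # a different seed: very likely a
                                           # different partition on Python 3.3+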
3 changes: 3 additions & 0 deletions bin/spark-submit2.cmd
@@ -20,6 +20,9 @@ rem
rem This is the entry point for running Spark submit. To avoid polluting the
rem environment, it just launches a new cmd to do the real work.

rem disable randomized hash for string in Python 3.3+
set PYTHONHASHSEED=1

set CLASS=org.apache.spark.deploy.SparkSubmit
call %~dp0spark-class2.cmd %CLASS% %*
set SPARK_ERROR_LEVEL=%ERRORLEVEL%
134 changes: 69 additions & 65 deletions ec2/spark_ec2.py

Large diffs are not rendered by default.

41 changes: 13 additions & 28 deletions python/pyspark/broadcast.py
@@ -17,15 +17,14 @@

import os
import sys
if sys.version < '3':
import cPickle
else:
import pickle as cPickle
basestring = str
unicode = str
import gc
from tempfile import NamedTemporaryFile

if sys.version < '3':
import cPickle as pickle
else:
import pickle
unicode = str

__all__ = ['Broadcast']

@@ -76,33 +75,19 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None):
self._path = path

def dump(self, value, f):
if isinstance(value, (str, bytes, unicode)):
if isinstance(value, unicode):
f.write(b'U')
value = value.encode('utf8')
else:
f.write(b'B')
f.write(value)
else:
f.write(b'P')
cPickle.dump(value, f, 2)
pickle.dump(value, f, 2)
f.close()
return f.name

def load(self, path):
with open(path, 'rb', 1 << 20) as f:
flag = f.read(1)
data = f.read()
if flag == b'P':
# cPickle.loads() may create lots of objects, disable GC
# temporarily for better performance
gc.disable()
try:
return cPickle.loads(data)
finally:
gc.enable()
else:
return data.decode('utf8') if flag == b'U' else data
# pickle.load() may create lots of objects, disable GC
# temporarily for better performance
gc.disable()
try:
return pickle.load(f)
finally:
gc.enable()

@property
def value(self):
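The simplified dump/load above always pickles the broadcast value with protocol 2 and unpickles it with the garbage collector switched off, since pickle.load() of a big object graph creates many objects at once and interleaved GC passes only slow it down. A standalone sketch of that pattern (the payload is made up):

import gc
import pickle
from tempfile import NamedTemporaryFile

value = {"word%d" % i: i for i in range(100000)}   # stand-in for a large broadcast value

with NamedTemporaryFile(delete=False) as f:
    pickle.dump(value, f, 2)          # protocol 2 is readable by Python 2 and 3
    path = f.name

with open(path, 'rb', 1 << 20) as f:  # 1 MB read buffer, as in Broadcast.load
    gc.disable()                      # skip GC while the object graph is rebuilt
    try:
        restored = pickle.load(f)
    finally:
        gc.enable()

assert restored == value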
5 changes: 3 additions & 2 deletions python/pyspark/conf.py
@@ -59,7 +59,8 @@
import sys
import re

if sys.version >= '3':
if sys.version > '3':
unicode = str
__doc__ = re.sub(r"(\W|^)[uU](['])", r'\1\2', __doc__)


@@ -106,7 +107,7 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None):

def set(self, key, value):
"""Set a configuration property."""
self._jconf.set(key, str(value))
self._jconf.set(key, unicode(value))
return self

def setIfMissing(self, key, value):
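The unicode = str alias lets the same unicode(value) call run on both interpreters: on Python 2 it avoids the UnicodeEncodeError that str() raises for non-ASCII text, and on Python 3 it is simply str(). A small sketch of the pattern (the sample values are made up):

import sys

if sys.version >= '3':
    unicode = str    # one name for "text" on both Python 2 and 3

values = [1, 2.5, True, u"espa\u00f1ol"]
texts = [unicode(v) for v in values]   # safe even for the non-ASCII entry
print(texts)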
8 changes: 5 additions & 3 deletions python/pyspark/context.py
@@ -1,4 +1,3 @@
from __future__ import print_function
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
@@ -16,6 +15,8 @@
# limitations under the License.
#

from __future__ import print_function

import os
import shutil
import sys
@@ -38,7 +39,7 @@
from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler

if sys.version >= '3':
if sys.version > '3':
xrange = range


@@ -98,6 +99,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
:param profiler_cls: A class of custom Profiler used to do profiling
(default is pyspark.profiler.BasicProfiler).
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
@@ -197,7 +199,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
self._temp_dir = \
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir, "pyspark")\
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir, "pyspark") \
.getAbsolutePath()

# profiling stats collected for each PythonRDD
1 change: 0 additions & 1 deletion python/pyspark/daemon.py
@@ -56,7 +56,6 @@ def worker(sock):
# otherwise writes also cause a seek that makes us miss data on the read side.
infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536)
outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536)

exit_code = 0
try:
worker_main(infile, outfile)
1 change: 0 additions & 1 deletion python/pyspark/heapq3.py
@@ -308,7 +308,6 @@
without surprises: heap[0] is the smallest item, and heap.sort()
maintains the heap invariant!
"""
from __future__ import print_function

# Original code by Kevin O'Connor, augmented by Tim Peters and Raymond Hettinger

8 changes: 5 additions & 3 deletions python/pyspark/mllib/__init__.py
@@ -31,6 +31,8 @@

import sys
from . import rand as random
random.__name__ = 'random'
random.RandomRDDs.__module__ = __name__ + '.random'
sys.modules[__name__ + '.random'] = random
modname = __name__ + '.random'
random.__name__ = modname
random.RandomRDDs.__module__ = modname
sys.modules[modname] = random
del modname, sys
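The block above publishes the private rand module under the public name pyspark.mllib.random by patching sys.modules and fixing up __name__ and __module__ so reprs and pickling refer to the public path. A self-contained sketch of the same trick with a made-up package (in MLlib the modules are real files on disk):

import sys
import types

# Fake parent package plus a 'rand'-style implementation module, built in
# memory so this example runs on its own.
pkg = types.ModuleType('mypkg')
pkg.__path__ = []                                  # mark it as a package
impl = types.ModuleType('mypkg.rand')
impl.RandomRDDs = type('RandomRDDs', (object,), {})
sys.modules['mypkg'] = pkg
sys.modules['mypkg.rand'] = impl

# The aliasing step mirrored from __init__.py above.
modname = 'mypkg.random'
impl.__name__ = modname
impl.RandomRDDs.__module__ = modname
sys.modules[modname] = impl
pkg.random = impl                                  # attribute access: mypkg.random
del modname

import mypkg.random
print(mypkg.random.RandomRDDs.__module__)          # -> mypkg.random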
8 changes: 4 additions & 4 deletions python/pyspark/mllib/common.py
@@ -18,7 +18,7 @@
import sys
if sys.version >= '3':
long = int
basestring = str
unicode = str

import py4j.protocol
from py4j.protocol import Py4JJavaError
@@ -79,11 +79,11 @@ def _py2java(sc, obj):
obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
elif isinstance(obj, JavaObject):
pass
elif isinstance(obj, (int, long, float, bool, basestring)):
elif isinstance(obj, (int, long, float, bool, bytes, str, unicode)):
pass
else:
bytes = bytearray(PickleSerializer().dumps(obj))
obj = sc._jvm.SerDe.loads(bytes)
data = bytearray(PickleSerializer().dumps(obj))
obj = sc._jvm.SerDe.loads(data)
return obj


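Renaming the local variable bytes to data is not just cosmetic: with the builtin bytes now used in the isinstance() check above, assigning to a local of the same name anywhere in the function would turn every reference to bytes in that function into a reference to the (possibly unbound) local. A small illustration of the pitfall with hypothetical helpers:

def broken(obj):
    # 'bytes' is assigned below, so Python treats it as a local for the whole
    # function body and this check raises UnboundLocalError at call time.
    if isinstance(obj, (int, float, bytes)):
        return obj
    bytes = bytearray(repr(obj).encode('utf-8'))
    return bytes

def fixed(obj):
    if isinstance(obj, (int, float, bytes)):   # builtin 'bytes' still visible
        return obj
    data = bytearray(repr(obj).encode('utf-8'))
    return data

print(fixed(42))
try:
    broken(42)
except UnboundLocalError as exc:
    print("broken(42) -> %s" % exc)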
11 changes: 8 additions & 3 deletions python/pyspark/mllib/feature.py
@@ -23,8 +23,10 @@
import sys
import warnings
import random
import binascii
if sys.version >= '3':
basestring = str
unicode = str

from py4j.protocol import Py4JJavaError

@@ -192,8 +194,8 @@ class HashingTF(object):
>>> htf = HashingTF(100)
>>> doc = u"a a b b c d".split(u" ")
>>> # htf.transform(doc)
# SparseVector(100, {1: 1.0, 14: 1.0, 31: 2.0, 44: 2.0})
>>> htf.transform(doc)
SparseVector(100, {55: 1.0, 59: 2.0, 81: 2.0, 88: 1.0})
"""
def __init__(self, numFeatures=1 << 20):
"""
@@ -203,7 +205,10 @@ def __init__(self, numFeatures=1 << 20):

def indexOf(self, term):
""" Returns the index of the input term. """
return hash(term) % self.numFeatures
# hash of string is not portable in Python 3
if isinstance(term, unicode):
term = term.encode('utf-8')
return (binascii.crc32(term) & 0x7FFFFFFF) % self.numFeatures

def transform(self, document):
"""
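binascii.crc32 returns the same 32-bit checksum on every interpreter and platform, unlike hash() of a str, which Python 3.3+ randomizes per process, so a term now always maps to the same feature index. A standalone sketch of the indexing scheme (numFeatures and the document are taken from the doctest above):

import binascii

NUM_FEATURES = 100

def index_of(term):
    # crc32 wants bytes; mask to a non-negative 31-bit value before the modulo.
    if not isinstance(term, bytes):
        term = term.encode('utf-8')
    return (binascii.crc32(term) & 0x7FFFFFFF) % NUM_FEATURES

doc = u"a a b b c d".split(u" ")
counts = {}
for term in doc:
    idx = index_of(term)
    counts[idx] = counts.get(idx, 0) + 1.0
print(sorted(counts.items()))   # the same indices as the SparseVector in the doctest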
8 changes: 3 additions & 5 deletions python/pyspark/mllib/linalg.py
@@ -22,18 +22,16 @@
object from MLlib or pass SciPy C{scipy.sparse} column vectors if
SciPy is available in their environment.
"""
from __future__ import print_function

import sys
import array
try:
import copy_reg
except ImportError:
import copyreg as copy_reg

if sys.version >= '3':
basestring = str
xrange = range
import copyreg as copy_reg
else:
import copy_reg

import numpy as np

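copy_reg was renamed to copyreg in Python 3; the version check above keeps a single name for it. linalg.py uses it to register custom pickle support for its vector types, and the registration pattern looks roughly like this sketch with a hypothetical class:

import pickle
import sys

if sys.version >= '3':
    import copyreg as copy_reg
else:
    import copy_reg

class DenseThing(object):                 # hypothetical stand-in for a vector type
    def __init__(self, values):
        self.values = list(values)

def _rebuild(values):
    return DenseThing(values)

def _reduce(obj):
    # Pickle only the payload and rebuild through _rebuild() on load.
    return _rebuild, (obj.values,)

copy_reg.pickle(DenseThing, _reduce)

v = pickle.loads(pickle.dumps(DenseThing([1.0, 0.0, 3.0]), 2))
print(v.values)                           # [1.0, 0.0, 3.0]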
1 change: 0 additions & 1 deletion python/pyspark/mllib/tests.py
@@ -18,7 +18,6 @@
"""
Fuller unit tests for Python MLlib.
"""
from __future__ import print_function

import os
import sys
6 changes: 2 additions & 4 deletions python/pyspark/profiler.py
@@ -15,8 +15,6 @@
# limitations under the License.
#

from __future__ import print_function

import cProfile
import pstats
import os
@@ -86,11 +84,11 @@ class Profiler(object):
>>> from pyspark import BasicProfiler
>>> class MyCustomProfiler(BasicProfiler):
... def show(self, id):
... print "My custom profiles for RDD:%s" % id
... print("My custom profiles for RDD:%s" % id)
...
>>> conf = SparkConf().set("spark.python.profile", "true")
>>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler)
>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
>>> sc.parallelize(range(1000)).map(lambda x: 2 * x).take(10)
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
>>> sc.show_profiles()
My custom profiles for RDD:1
19 changes: 8 additions & 11 deletions python/pyspark/rdd.py
@@ -54,12 +54,9 @@
from py4j.java_collections import ListConverter, MapConverter



__all__ = ["RDD"]


# TODO: for Python 3.3+, PYTHONHASHSEED should be reset to disable randomized
# hash for string
def portable_hash(x):
"""
This function returns a consistent hash code for builtin types, especially
@@ -355,7 +352,7 @@ def distinct(self, numPartitions=None):
"""
return self.map(lambda x: (x, None)) \
.reduceByKey(lambda x, _: x, numPartitions) \
.map(lambda x__: x__[0])
.map(lambda x: x[0])

def sample(self, withReplacement, fraction, seed=None):
"""
@@ -595,7 +592,7 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):

def sortPartition(iterator):
sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
return iter(sort(iterator, key=lambda k_v1: keyfunc(k_v1[0]), reverse=(not ascending)))
return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending)))

if numPartitions == 1:
if self.getNumPartitions() > 1:
@@ -610,7 +607,7 @@ def sortPartition(iterator):
return self # empty RDD
maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner
fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
samples = self.sample(False, fraction, 1).map(lambda k_v2: k_v2[0]).collect()
samples = self.sample(False, fraction, 1).map(lambda kv: kv[0]).collect()
samples = sorted(samples, reverse=(not ascending), key=keyfunc)

# we have numPartitions many parts but one of the them has
@@ -1479,7 +1476,7 @@ def keys(self):
>>> m.collect()
[1, 3]
"""
return self.map(lambda k_v3: k_v3[0])
return self.map(lambda x: x[0])

def values(self):
"""
@@ -1489,7 +1486,7 @@ def values(self):
>>> m.collect()
[2, 4]
"""
return self.map(lambda k_v4: k_v4[1])
return self.map(lambda x: x[1])

def reduceByKey(self, func, numPartitions=None):
"""
@@ -1816,7 +1813,7 @@ def flatMapValues(self, f):
>>> x.flatMapValues(f).collect()
[('a', 'x'), ('a', 'y'), ('a', 'z'), ('b', 'p'), ('b', 'r')]
"""
flat_map_fn = lambda k_v6: ((k_v6[0], x) for x in f(k_v6[1]))
flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1]))
return self.flatMap(flat_map_fn, preservesPartitioning=True)

def mapValues(self, f):
@@ -1830,7 +1827,7 @@ def mapValues(self, f):
>>> x.mapValues(f).collect()
[('a', 3), ('b', 1)]
"""
map_values_fn = lambda k_v7: (k_v7[0], f(k_v7[1]))
map_values_fn = lambda kv: (kv[0], f(kv[1]))
return self.map(map_values_fn, preservesPartitioning=True)

def groupWith(self, other, *others):
@@ -2119,7 +2116,7 @@ def lookup(self, key):
>>> sorted.lookup(1024)
[]
"""
values = self.filter(lambda k_v5: k_v5[0] == key).values()
values = self.filter(lambda kv: kv[0] == key).values()

if self.partitioner is not None:
return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False)
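The sortByKey hunk above samples roughly numPartitions * 20 keys (the RangePartitioner constant) and sorts the sample; the sorted sample is then cut into evenly spaced keys that become the range-partition boundaries. A standalone sketch of how such boundaries can be computed from a sample (the data, partition count, and the even-cut step are illustrative):

import random

def range_bounds(keys, num_partitions, keyfunc=lambda x: x, ascending=True):
    # Sample about 20 keys per output partition, as in the diff above.
    max_sample_size = num_partitions * 20.0
    fraction = min(max_sample_size / max(len(keys), 1), 1.0)
    samples = [k for k in keys if random.random() < fraction]
    samples = sorted(samples, reverse=(not ascending), key=keyfunc)
    # Pick num_partitions - 1 evenly spaced keys as partition boundaries.
    return [samples[len(samples) * (i + 1) // num_partitions]
            for i in range(num_partitions - 1)]

keys = [random.randint(0, 999) for _ in range(10000)]
print(range_bounds(keys, 4))   # three boundary keys, roughly the quartiles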
12 changes: 7 additions & 5 deletions python/pyspark/serializers.py
@@ -57,10 +57,11 @@
import collections
import zlib
import itertools
try:

if sys.version < '3':
import cPickle as pickle
protocol = 2
except ImportError:
else:
import pickle
protocol = 3

@@ -387,10 +388,11 @@ class PickleSerializer(FramedSerializer):
def dumps(self, obj):
return pickle.dumps(obj, protocol)

def loads(self, obj):
if sys.version >= '3':
if sys.version >= '3':
def loads(self, obj):
return pickle.loads(obj, encoding='bytes')
else:
else:
def loads(self, obj):
return pickle.loads(obj)


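Protocol 2 is the newest pickle protocol Python 2 can read, while protocol 3 adds native bytes support on Python 3; encoding='bytes' lets a Python 3 process load protocol-2 pickles of byte strings produced on Python 2 without guessing a text codec. A minimal round-trip sketch of the same version split:

import pickle
import sys

if sys.version < '3':
    protocol = 2
else:
    protocol = 3

def dumps(obj):
    return pickle.dumps(obj, protocol)

if sys.version >= '3':
    def loads(data):
        # Keep Python 2 str pickles as bytes instead of decoding them as text.
        return pickle.loads(data, encoding='bytes')
else:
    def loads(data):
        return pickle.loads(data)

payload = {'key': b'\x00\x01binary', 'count': 3}
assert loads(dumps(payload)) == payload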