Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX check the type of the entries in sys.modules #326

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,13 @@ def _whichmodule(obj, name):
# modules that trigger imports of other modules upon calls to getattr or
# other threads importing at the same time.
for module_name, module in sys.modules.copy().items():
if module_name == '__main__' or module is None:
# Some modules such as coverage can inject non-module objects inside
# sys.modules
if (
module_name == '__main__' or
module is None or
not isinstance(module, types.ModuleType)
pierreglaser marked this conversation as resolved.
Show resolved Hide resolved
):
continue
try:
if _getattribute(module, name)[0] is obj:
Expand Down
94 changes: 72 additions & 22 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import cloudpickle
from cloudpickle.cloudpickle import _is_dynamic
from cloudpickle.cloudpickle import _make_empty_cell, cell_set
from cloudpickle.cloudpickle import _extract_class_dict
from cloudpickle.cloudpickle import _extract_class_dict, _whichmodule

from .testutils import subprocess_pickle_echo
from .testutils import assert_run_python_script
Expand Down Expand Up @@ -1048,35 +1048,85 @@ def __init__(self, x):

self.assertEqual(set(weakset), {depickled1, depickled2})

def test_non_module_object_passing_whichmodule_test(self):
# https://github.com/cloudpipe/cloudpickle/pull/326: cloudpickle should
# not try to instrospect non-modules object when trying to discover the
# module of a function/class
def func(x):
return x ** 2

func.__module__ = None
pierreglaser marked this conversation as resolved.
Show resolved Hide resolved

class NonModuleObject(object):
def __getattr__(self, name):
# We whitelist func so that a _whichmodule(func, None) call returns
# the NonModuleObject instance if a type check on the entries
# of sys.modules is not carried out, but manipulating this
# instance thinking it really is a module later on in the
# pickling process of func errors out
if name == 'func':
return func
else:
raise ValueError
pierreglaser marked this conversation as resolved.
Show resolved Hide resolved

non_module_object = NonModuleObject()

assert func(2) == 4
assert func is non_module_object.func

# Any manipulation of non_module_object relying on attribute access
# will raise an Exception
with pytest.raises(ValueError):
pierreglaser marked this conversation as resolved.
Show resolved Hide resolved
_is_dynamic(non_module_object)

try:
sys.modules['NonModuleObject'] = non_module_object

func_module_name = _whichmodule(func, None)
assert func_module_name != 'NonModuleObject'
assert func_module_name is None

depickled_func = pickle_depickle(func, protocol=self.protocol)
assert depickled_func(2) == 4

finally:
sys.modules.pop('NonModuleObject')

def test_faulty_module(self):
pierreglaser marked this conversation as resolved.
Show resolved Hide resolved
for module_name in ['_missing_module', None]:
class FaultyModule(object):
def __getattr__(self, name):
# This throws an exception while looking up within
# pickle.whichmodule or getattr(module, name, None)
raise Exception()
for base_class in (object, types.ModuleType):
for module_name in ['_missing_module', None]:
class FaultyModule(base_class):
def __getattr__(self, name):
# This throws an exception while looking up within
# pickle.whichmodule or getattr(module, name, None)
raise Exception()

class Foo(object):
__module__ = module_name

class Foo(object):
__module__ = module_name
def foo(self):
return "it works!"

def foo(self):
def foo():
return "it works!"

def foo():
return "it works!"
foo.__module__ = module_name

foo.__module__ = module_name
if base_class is types.ModuleType: # noqa
faulty_module = FaultyModule('_faulty_module')
else:
faulty_module = FaultyModule()
sys.modules["_faulty_module"] = faulty_module

sys.modules["_faulty_module"] = FaultyModule()
try:
# Test whichmodule in save_global.
self.assertEqual(pickle_depickle(Foo()).foo(), "it works!")
try:
# Test whichmodule in save_global.
self.assertEqual(pickle_depickle(Foo()).foo(), "it works!")

# Test whichmodule in save_function.
cloned = pickle_depickle(foo, protocol=self.protocol)
self.assertEqual(cloned(), "it works!")
finally:
sys.modules.pop("_faulty_module", None)
# Test whichmodule in save_function.
cloned = pickle_depickle(foo, protocol=self.protocol)
self.assertEqual(cloned(), "it works!")
finally:
sys.modules.pop("_faulty_module", None)

def test_dynamic_pytest_module(self):
# Test case for pull request https://github.com/cloudpipe/cloudpickle/pull/116
Expand Down