Skip to content

Commit

Permalink
pythongh-122311: Improve and unify pickle errors
Browse files Browse the repository at this point in the history
* Raise PicklingError instead of UnicodeEncodeError, ValueError
  and AttributeError in both implementations.
* Chain the original exception to the pickle-specific one as __context__.
* Include the error message of ImportError and some AttributeError in
  the PicklingError error message.
* Unify error messages between Python and C implementations.
* Refer to documented __reduce__ and __newobj__ callables instead of
  internal methods (e.g. save_reduce()) or pickle opcodes (e.g. NEWOBJ).
* Include more details in error messages (what expected, what got).
* Avoid including a potentially long repr of an arbitrary object in
  error messages.
  • Loading branch information
serhiy-storchaka committed Aug 7, 2024
1 parent 9e551f9 commit 554abbc
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 223 deletions.
90 changes: 49 additions & 41 deletions Lib/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,9 @@ def whichmodule(obj, name):
"""Find the module an object belong to."""
dotted_path = name.split('.')
module_name = getattr(obj, '__module__', None)
if module_name is None and '<locals>' not in dotted_path:
if '<locals>' in dotted_path:
raise PicklingError(f"Can't pickle local object {obj!r}")
if module_name is None:
# Protect the iteration by using a list copy of sys.modules against dynamic
# modules that trigger imports of other modules upon calls to getattr.
for module_name, module in sys.modules.copy().items():
Expand All @@ -336,22 +338,21 @@ def whichmodule(obj, name):
except AttributeError:
pass
module_name = '__main__'
elif module_name is None:
module_name = '__main__'

try:
__import__(module_name, level=0)
module = sys.modules[module_name]
except (ImportError, ValueError, KeyError) as exc:
raise PicklingError(f"Can't pickle {obj!r}: {exc!s}")
try:
if _getattribute(module, dotted_path) is obj:
return module_name
except (ImportError, KeyError, AttributeError):
raise PicklingError(
"Can't pickle %r: it's not found as %s.%s" %
(obj, module_name, name)) from None
except AttributeError:
raise PicklingError(f"Can't pickle {obj!r}: "
f"it's not found as {module_name}.{name}")

raise PicklingError(
"Can't pickle %r: it's not the same object as %s.%s" %
(obj, module_name, name))
f"Can't pickle {obj!r}: it's not the same object as {module_name}.{name}")

def encode_long(x):
r"""Encode a long to a two's complement little-endian binary string.
Expand Down Expand Up @@ -403,6 +404,13 @@ def decode_long(data):
"""
return int.from_bytes(data, byteorder='little', signed=True)

def _T(obj):
cls = type(obj)
module = cls.__module__
if module in (None, 'builtins', '__main__'):
return cls.__qualname__
return f'{module}.{cls.__qualname__}'


_NoValue = object()

Expand Down Expand Up @@ -585,8 +593,7 @@ def save(self, obj, save_persistent_id=True):
if reduce is not _NoValue:
rv = reduce()
else:
raise PicklingError("Can't pickle %r object: %r" %
(t.__name__, obj))
raise PicklingError(f"Can't pickle {_T(t)} object")

# Check for string returned by reduce(), meaning "save as global"
if isinstance(rv, str):
Expand All @@ -595,13 +602,13 @@ def save(self, obj, save_persistent_id=True):

# Assert that reduce() returned a tuple
if not isinstance(rv, tuple):
raise PicklingError("%s must return string or tuple" % reduce)
raise PicklingError(f'__reduce__ must return a string or tuple, not {_T(rv)}')

# Assert that it returned an appropriately sized tuple
l = len(rv)
if not (2 <= l <= 6):
raise PicklingError("Tuple returned by %s must have "
"two to six elements" % reduce)
raise PicklingError("tuple returned by __reduce__ "
"must contain 2 through 6 elements")

# Save the reduce() output and finally memoize the object
self.save_reduce(obj=obj, *rv)
Expand All @@ -626,10 +633,12 @@ def save_reduce(self, func, args, state=None, listitems=None,
dictitems=None, state_setter=None, *, obj=None):
# This API is called by some subclasses

if not isinstance(args, tuple):
raise PicklingError("args from save_reduce() must be a tuple")
if not callable(func):
raise PicklingError("func from save_reduce() must be callable")
raise PicklingError(f"first item of the tuple returned by __reduce__ "
f"must be callable, not {_T(func)}")
if not isinstance(args, tuple):
raise PicklingError(f"second item of the tuple returned by __reduce__ "
f"must be a tuple, not {_T(args)}")

save = self.save
write = self.write
Expand All @@ -638,11 +647,10 @@ def save_reduce(self, func, args, state=None, listitems=None,
if self.proto >= 2 and func_name == "__newobj_ex__":
cls, args, kwargs = args
if not hasattr(cls, "__new__"):
raise PicklingError("args[0] from {} args has no __new__"
.format(func_name))
raise PicklingError("first argument to __newobj_ex__() has no __new__")
if obj is not None and cls is not obj.__class__:
raise PicklingError("args[0] from {} args has the wrong class"
.format(func_name))
raise PicklingError(f"first argument to __newobj_ex__() "
f"must be {obj.__class__!r}, not {cls!r}")
if self.proto >= 4:
save(cls)
save(args)
Expand Down Expand Up @@ -682,11 +690,10 @@ def save_reduce(self, func, args, state=None, listitems=None,
# Python 2.2).
cls = args[0]
if not hasattr(cls, "__new__"):
raise PicklingError(
"args[0] from __newobj__ args has no __new__")
raise PicklingError("first argument to __newobj__() has no __new__")
if obj is not None and cls is not obj.__class__:
raise PicklingError(
"args[0] from __newobj__ args has the wrong class")
raise PicklingError(f"first argument to __newobj__() "
f"must be {obj.__class__!r}, not {cls!r}")
args = args[1:]
save(cls)
save(args)
Expand Down Expand Up @@ -1128,8 +1135,7 @@ def save_global(self, obj, name=None):
def _save_toplevel_by_name(self, module_name, name):
if self.proto >= 3:
# Non-ASCII identifiers are supported only with protocols >= 3.
self.write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
bytes(name, "utf-8") + b'\n')
encoding = "utf-8"
else:
if self.fix_imports:
r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
Expand All @@ -1138,13 +1144,19 @@ def _save_toplevel_by_name(self, module_name, name):
module_name, name = r_name_mapping[(module_name, name)]
elif module_name in r_import_mapping:
module_name = r_import_mapping[module_name]
try:
self.write(GLOBAL + bytes(module_name, "ascii") + b'\n' +
bytes(name, "ascii") + b'\n')
except UnicodeEncodeError:
raise PicklingError(
"can't pickle global identifier '%s.%s' using "
"pickle protocol %i" % (module_name, name, self.proto)) from None
encoding = "ascii"
try:
self.write(GLOBAL + bytes(module_name, encoding) + b'\n')
except UnicodeEncodeError:
raise PicklingError(
f"can't pickle module identifier {module_name!r} using "
f"pickle protocol {self.proto}")
try:
self.write(bytes(name, encoding) + b'\n')
except UnicodeEncodeError:
raise PicklingError(
f"can't pickle global identifier {name!r} using "
f"pickle protocol {self.proto}")

def save_type(self, obj):
if obj is type(None):
Expand Down Expand Up @@ -1605,17 +1617,13 @@ def find_class(self, module, name):
elif module in _compat_pickle.IMPORT_MAPPING:
module = _compat_pickle.IMPORT_MAPPING[module]
__import__(module, level=0)
if self.proto >= 4:
module = sys.modules[module]
if self.proto >= 4 and '.' in name:
dotted_path = name.split('.')
if '<locals>' in dotted_path:
raise AttributeError(
f"Can't get local attribute {name!r} on {module!r}")
try:
return _getattribute(module, dotted_path)
return _getattribute(sys.modules[module], dotted_path)
except AttributeError:
raise AttributeError(
f"Can't get attribute {name!r} on {module!r}") from None
f"Can't resolve path {name!r} on module {module!r}")
else:
return getattr(sys.modules[module], name)

Expand Down
Loading

0 comments on commit 554abbc

Please sign in to comment.