Skip to content

Commit

Permalink
determine spec shape only at mock construction time
Browse files Browse the repository at this point in the history
  • Loading branch information
carljm committed Dec 16, 2022
1 parent 1b6241a commit 6595272
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
8 changes: 7 additions & 1 deletion Lib/test/test_unittest/testmock/testasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,15 @@ def test_spec_normal_methods_on_class_with_mock(self):
def test_spec_async_attributes_instance(self):
async_instance = AsyncClass()
async_instance.async_func_attr = async_func
async_instance.later_async_func_attr = normal_func

mock_async_instance = Mock(spec_set=async_instance)

async_instance.later_async_func_attr = async_func

mock_async_instance = Mock(async_instance)
self.assertIsInstance(mock_async_instance.async_func_attr, AsyncMock)
# only the shape of the spec at the time of mock construction matters
self.assertNotIsInstance(mock_async_instance.later_async_func_attr, AsyncMock)

def test_spec_mock_type_kw(self):
def inner_test(mock_type):
Expand Down
16 changes: 10 additions & 6 deletions Lib/unittest/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,10 +506,9 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False,

_spec_class = None
_spec_signature = None
_spec_obj = None
_spec_asyncs = []

if spec is not None and not _is_list(spec):
_spec_obj = spec
if isinstance(spec, type):
_spec_class = spec
else:
Expand All @@ -518,14 +517,20 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False,
_spec_as_instance, _eat_self)
_spec_signature = res and res[1]

spec = dir(spec)
spec_list = dir(spec)

for attr in spec_list:
if iscoroutinefunction(getattr(spec, attr, None)):
_spec_asyncs.append(attr)

spec = spec_list

__dict__ = self.__dict__
__dict__['_spec_class'] = _spec_class
__dict__['_spec_obj'] = _spec_obj
__dict__['_spec_set'] = spec_set
__dict__['_spec_signature'] = _spec_signature
__dict__['_mock_methods'] = spec
__dict__['_spec_asyncs'] = _spec_asyncs

def __get_return_value(self):
ret = self._mock_return_value
Expand Down Expand Up @@ -1015,8 +1020,7 @@ def _get_child_mock(self, /, **kw):
For non-callable mocks the callable variant will be used (rather than
any custom subclass)."""
_new_name = kw.get("_new_name")
_spec_val = getattr(self.__dict__["_spec_obj"], _new_name, None)
if _spec_val is not None and asyncio.iscoroutinefunction(_spec_val):
if _new_name in self.__dict__['_spec_asyncs']:
return AsyncMock(**kw)

if self._mock_sealed:
Expand Down

0 comments on commit 6595272

Please sign in to comment.