diff --git a/Doc/library/unittest.mock.rst b/Doc/library/unittest.mock.rst index ed00ee6d0c2d83..21e4709f81609e 100644 --- a/Doc/library/unittest.mock.rst +++ b/Doc/library/unittest.mock.rst @@ -201,9 +201,11 @@ The Mock Class .. testsetup:: + import asyncio + import inspect import unittest from unittest.mock import sentinel, DEFAULT, ANY - from unittest.mock import patch, call, Mock, MagicMock, PropertyMock + from unittest.mock import patch, call, Mock, MagicMock, PropertyMock, AsyncMock from unittest.mock import mock_open :class:`Mock` is a flexible mock object intended to replace the use of stubs and @@ -851,6 +853,217 @@ object:: >>> p.assert_called_once_with() +.. class:: AsyncMock(spec=None, side_effect=None, return_value=DEFAULT, wraps=None, name=None, spec_set=None, unsafe=False, **kwargs) + + An asynchronous version of :class:`Mock`. The :class:`AsyncMock` object will + behave so the object is recognized as an async function, and the result of a + call is an awaitable. + + >>> mock = AsyncMock() + >>> asyncio.iscoroutinefunction(mock) + True + >>> inspect.isawaitable(mock()) + True + + The result of ``mock()`` is an async function which will have the outcome + of ``side_effect`` or ``return_value``: + + - if ``side_effect`` is a function, the async function will return the + result of that function, + - if ``side_effect`` is an exception, the async function will raise the + exception, + - if ``side_effect`` is an iterable, the async function will return the + next value of the iterable, however, if the sequence of result is + exhausted, ``StopIteration`` is raised immediately, + - if ``side_effect`` is not defined, the async function will return the + value defined by ``return_value``, hence, by default, the async function + returns a new :class:`AsyncMock` object. + + + Setting the *spec* of a :class:`Mock` or :class:`MagicMock` to an async function + will result in a coroutine object being returned after calling. + + >>> async def async_func(): pass + ... + >>> mock = MagicMock(async_func) + >>> mock + + >>> mock() + + + .. method:: assert_awaited() + + Assert that the mock was awaited at least once. + + >>> mock = AsyncMock() + >>> async def main(): + ... await mock() + ... + >>> asyncio.run(main()) + >>> mock.assert_awaited() + >>> mock_2 = AsyncMock() + >>> mock_2.assert_awaited() + Traceback (most recent call last): + ... + AssertionError: Expected mock to have been awaited. + + .. method:: assert_awaited_once() + + Assert that the mock was awaited exactly once. + + >>> mock = AsyncMock() + >>> async def main(): + ... await mock() + ... + >>> asyncio.run(main()) + >>> mock.assert_awaited_once() + >>> asyncio.run(main()) + >>> mock.method.assert_awaited_once() + Traceback (most recent call last): + ... + AssertionError: Expected mock to have been awaited once. Awaited 2 times. + + .. method:: assert_awaited_with(*args, **kwargs) + + Assert that the last await was with the specified arguments. + + >>> mock = AsyncMock() + >>> async def main(*args, **kwargs): + ... await mock(*args, **kwargs) + ... + >>> asyncio.run(main('foo', bar='bar')) + >>> mock.assert_awaited_with('foo', bar='bar') + >>> mock.assert_awaited_with('other') + Traceback (most recent call last): + ... + AssertionError: expected call not found. + Expected: mock('other') + Actual: mock('foo', bar='bar') + + .. method:: assert_awaited_once_with(*args, **kwargs) + + Assert that the mock was awaited exactly once and with the specified + arguments. + + >>> mock = AsyncMock() + >>> async def main(*args, **kwargs): + ... await mock(*args, **kwargs) + ... + >>> asyncio.run(main('foo', bar='bar')) + >>> mock.assert_awaited_once_with('foo', bar='bar') + >>> asyncio.run(main('foo', bar='bar')) + >>> mock.assert_awaited_once_with('foo', bar='bar') + Traceback (most recent call last): + ... + AssertionError: Expected mock to have been awaited once. Awaited 2 times. + + .. method:: assert_any_await(*args, **kwargs) + + Assert the mock has ever been awaited with the specified arguments. + + >>> mock = AsyncMock() + >>> async def main(*args, **kwargs): + ... await mock(*args, **kwargs) + ... + >>> asyncio.run(main('foo', bar='bar')) + >>> asyncio.run(main('hello')) + >>> mock.assert_any_await('foo', bar='bar') + >>> mock.assert_any_await('other') + Traceback (most recent call last): + ... + AssertionError: mock('other') await not found + + .. method:: assert_has_awaits(calls, any_order=False) + + Assert the mock has been awaited with the specified calls. + The :attr:`await_args_list` list is checked for the awaits. + + If *any_order* is False (the default) then the awaits must be + sequential. There can be extra calls before or after the + specified awaits. + + If *any_order* is True then the awaits can be in any order, but + they must all appear in :attr:`await_args_list`. + + >>> mock = AsyncMock() + >>> async def main(*args, **kwargs): + ... await mock(*args, **kwargs) + ... + >>> calls = [call("foo"), call("bar")] + >>> mock.assert_has_calls(calls) + Traceback (most recent call last): + ... + AssertionError: Calls not found. + Expected: [call('foo'), call('bar')] + >>> asyncio.run(main('foo')) + >>> asyncio.run(main('bar')) + >>> mock.assert_has_calls(calls) + + .. method:: assert_not_awaited() + + Assert that the mock was never awaited. + + >>> mock = AsyncMock() + >>> mock.assert_not_awaited() + + .. method:: reset_mock(*args, **kwargs) + + See :func:`Mock.reset_mock`. Also sets :attr:`await_count` to 0, + :attr:`await_args` to None, and clears the :attr:`await_args_list`. + + .. attribute:: await_count + + An integer keeping track of how many times the mock object has been awaited. + + >>> mock = AsyncMock() + >>> async def main(): + ... await mock() + ... + >>> asyncio.run(main()) + >>> mock.await_count + 1 + >>> asyncio.run(main()) + >>> mock.await_count + 2 + + .. attribute:: await_args + + This is either ``None`` (if the mock hasn’t been awaited), or the arguments that + the mock was last awaited with. Functions the same as :attr:`Mock.call_args`. + + >>> mock = AsyncMock() + >>> async def main(*args): + ... await mock(*args) + ... + >>> mock.await_args + >>> asyncio.run(main('foo')) + >>> mock.await_args + call('foo') + >>> asyncio.run(main('bar')) + >>> mock.await_args + call('bar') + + + .. attribute:: await_args_list + + This is a list of all the awaits made to the mock object in sequence (so the + length of the list is the number of times it has been awaited). Before any + awaits have been made it is an empty list. + + >>> mock = AsyncMock() + >>> async def main(*args): + ... await mock(*args) + ... + >>> mock.await_args_list + [] + >>> asyncio.run(main('foo')) + >>> mock.await_args_list + [call('foo')] + >>> asyncio.run(main('bar')) + >>> mock.await_args_list + [call('foo'), call('bar')] + + Calling ~~~~~~~ diff --git a/Doc/whatsnew/3.8.rst b/Doc/whatsnew/3.8.rst index d6388f8faaba4f..f3f9cedb5ce217 100644 --- a/Doc/whatsnew/3.8.rst +++ b/Doc/whatsnew/3.8.rst @@ -508,6 +508,10 @@ unicodedata unittest -------- +* XXX Added :class:`AsyncMock` to support an asynchronous version of :class:`Mock`. + Appropriate new assert functions for testing have been added as well. + (Contributed by Lisa Roach in :issue:`26467`). + * Added :func:`~unittest.addModuleCleanup()` and :meth:`~unittest.TestCase.addClassCleanup()` to unittest to support cleanups for :func:`~unittest.setUpModule()` and diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index 1e8057d5f5bb5f..2ebbcbe369631b 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -13,6 +13,7 @@ 'ANY', 'call', 'create_autospec', + 'AsyncMock', 'FILTER_DIR', 'NonCallableMock', 'NonCallableMagicMock', @@ -24,13 +25,13 @@ __version__ = '1.0' - +import asyncio import io import inspect import pprint import sys import builtins -from types import ModuleType, MethodType +from types import CodeType, ModuleType, MethodType from unittest.util import safe_repr from functools import wraps, partial @@ -43,6 +44,13 @@ # Without this, the __class__ properties wouldn't be set correctly _safe_super = super +def _is_async_obj(obj): + if getattr(obj, '__code__', None): + return asyncio.iscoroutinefunction(obj) or inspect.isawaitable(obj) + else: + return False + + def _is_instance_mock(obj): # can't use isinstance on Mock objects because they override __class__ # The base class for all mocks is NonCallableMock @@ -355,7 +363,20 @@ def __new__(cls, *args, **kw): # every instance has its own class # so we can create magic methods on the # class without stomping on other mocks - new = type(cls.__name__, (cls,), {'__doc__': cls.__doc__}) + bases = (cls,) + if not issubclass(cls, AsyncMock): + # Check if spec is an async object or function + sig = inspect.signature(NonCallableMock.__init__) + bound_args = sig.bind_partial(cls, *args, **kw).arguments + spec_arg = [ + arg for arg in bound_args.keys() + if arg.startswith('spec') + ] + if spec_arg: + # what if spec_set is different than spec? + if _is_async_obj(bound_args[spec_arg[0]]): + bases = (AsyncMockMixin, cls,) + new = type(cls.__name__, bases, {'__doc__': cls.__doc__}) instance = object.__new__(new) return instance @@ -431,6 +452,11 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False, _eat_self=False): _spec_class = None _spec_signature = None + _spec_asyncs = [] + + for attr in dir(spec): + if asyncio.iscoroutinefunction(getattr(spec, attr, None)): + _spec_asyncs.append(attr) if spec is not None and not _is_list(spec): if isinstance(spec, type): @@ -448,7 +474,7 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False, __dict__['_spec_set'] = spec_set __dict__['_spec_signature'] = _spec_signature __dict__['_mock_methods'] = spec - + __dict__['_spec_asyncs'] = _spec_asyncs def __get_return_value(self): ret = self._mock_return_value @@ -885,7 +911,15 @@ def _get_child_mock(self, **kw): For non-callable mocks the callable variant will be used (rather than any custom subclass).""" + _new_name = kw.get("_new_name") + if _new_name in self.__dict__['_spec_asyncs']: + return AsyncMock(**kw) + _type = type(self) + if issubclass(_type, MagicMock) and _new_name in _async_method_magics: + klass = AsyncMock + if issubclass(_type, AsyncMockMixin): + klass = MagicMock if not issubclass(_type, CallableMixin): if issubclass(_type, NonCallableMagicMock): klass = MagicMock @@ -931,14 +965,12 @@ def _try_iter(obj): return obj - class CallableMixin(Base): def __init__(self, spec=None, side_effect=None, return_value=DEFAULT, wraps=None, name=None, spec_set=None, parent=None, _spec_state=None, _new_name='', _new_parent=None, **kwargs): self.__dict__['_mock_return_value'] = return_value - _safe_super(CallableMixin, self).__init__( spec, wraps, name, spec_set, parent, _spec_state, _new_name, _new_parent, **kwargs @@ -1080,7 +1112,6 @@ class or instance) that acts as the specification for the mock object. If """ - def _dot_lookup(thing, comp, import_path): try: return getattr(thing, comp) @@ -1278,8 +1309,10 @@ def __enter__(self): if isinstance(original, type): # If we're patching out a class and there is a spec inherit = True - - Klass = MagicMock + if spec is None and _is_async_obj(original): + Klass = AsyncMock + else: + Klass = MagicMock _kwargs = {} if new_callable is not None: Klass = new_callable @@ -1291,7 +1324,9 @@ def __enter__(self): not_callable = '__call__' not in this_spec else: not_callable = not callable(this_spec) - if not_callable: + if _is_async_obj(this_spec): + Klass = AsyncMock + elif not_callable: Klass = NonCallableMagicMock if spec is not None: @@ -1732,7 +1767,7 @@ def _patch_stopall(): '__reduce__', '__reduce_ex__', '__getinitargs__', '__getnewargs__', '__getstate__', '__setstate__', '__getformat__', '__setformat__', '__repr__', '__dir__', '__subclasses__', '__format__', - '__getnewargs_ex__', + '__getnewargs_ex__', '__aenter__', '__aexit__', '__anext__', '__aiter__', } @@ -1749,6 +1784,11 @@ def method(self, *args, **kw): ' '.join([magic_methods, numerics, inplace, right]).split() } +# Magic methods used for async `with` statements +_async_method_magics = {"__aenter__", "__aexit__", "__anext__"} +# `__aiter__` is a plain function but used with async calls +_async_magics = _async_method_magics | {"__aiter__"} + _all_magics = _magics | _non_defaults _unsupported_magics = { @@ -1778,6 +1818,7 @@ def method(self, *args, **kw): '__float__': 1.0, '__bool__': True, '__index__': 1, + '__aexit__': False, } @@ -1810,10 +1851,19 @@ def __iter__(): return iter(ret_val) return __iter__ +def _get_async_iter(self): + def __aiter__(): + ret_val = self.__aiter__._mock_return_value + if ret_val is DEFAULT: + return _AsyncIterator(iter([])) + return _AsyncIterator(iter(ret_val)) + return __aiter__ + _side_effect_methods = { '__eq__': _get_eq, '__ne__': _get_ne, '__iter__': _get_iter, + '__aiter__': _get_async_iter } @@ -1878,8 +1928,33 @@ def mock_add_spec(self, spec, spec_set=False): self._mock_set_magics() +class AsyncMagicMixin: + def __init__(self, *args, **kw): + self._mock_set_async_magics() # make magic work for kwargs in init + _safe_super(AsyncMagicMixin, self).__init__(*args, **kw) + self._mock_set_async_magics() # fix magic broken by upper level init + + def _mock_set_async_magics(self): + these_magics = _async_magics -class MagicMock(MagicMixin, Mock): + if getattr(self, "_mock_methods", None) is not None: + these_magics = _async_magics.intersection(self._mock_methods) + remove_magics = _async_magics - these_magics + + for entry in remove_magics: + if entry in type(self).__dict__: + # remove unneeded magic methods + delattr(self, entry) + + # don't overwrite existing attributes if called a second time + these_magics = these_magics - set(type(self).__dict__) + + _type = type(self) + for entry in these_magics: + setattr(_type, entry, MagicProxy(entry, self)) + + +class MagicMock(MagicMixin, AsyncMagicMixin, Mock): """ MagicMock is a subclass of Mock with default implementations of most of the magic methods. You can use MagicMock without having to @@ -1919,6 +1994,218 @@ def __get__(self, obj, _type=None): return self.create_mock() +class AsyncMockMixin(Base): + awaited = _delegating_property('awaited') + await_count = _delegating_property('await_count') + await_args = _delegating_property('await_args') + await_args_list = _delegating_property('await_args_list') + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # asyncio.iscoroutinefunction() checks _is_coroutine property to say if an + # object is a coroutine. Without this check it looks to see if it is a + # function/method, which in this case it is not (since it is an + # AsyncMock). + # It is set through __dict__ because when spec_set is True, this + # attribute is likely undefined. + self.__dict__['_is_coroutine'] = asyncio.coroutines._is_coroutine + self.__dict__['_mock_awaited'] = _AwaitEvent(self) + self.__dict__['_mock_await_count'] = 0 + self.__dict__['_mock_await_args'] = None + self.__dict__['_mock_await_args_list'] = _CallList() + code_mock = NonCallableMock(spec_set=CodeType) + code_mock.co_flags = inspect.CO_COROUTINE + self.__dict__['__code__'] = code_mock + + async def _mock_call(_mock_self, *args, **kwargs): + self = _mock_self + try: + result = super()._mock_call(*args, **kwargs) + except (BaseException, StopIteration) as e: + side_effect = self.side_effect + if side_effect is not None and not callable(side_effect): + raise + return await _raise(e) + + _call = self.call_args + + async def proxy(): + try: + if inspect.isawaitable(result): + return await result + else: + return result + finally: + self.await_count += 1 + self.await_args = _call + self.await_args_list.append(_call) + await self.awaited._notify() + + return await proxy() + + def assert_awaited(_mock_self): + """ + Assert that the mock was awaited at least once. + """ + self = _mock_self + if self.await_count == 0: + msg = f"Expected {self._mock_name or 'mock'} to have been awaited." + raise AssertionError(msg) + + def assert_awaited_once(_mock_self): + """ + Assert that the mock was awaited exactly once. + """ + self = _mock_self + if not self.await_count == 1: + msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once." + f" Awaited {self.await_count} times.") + raise AssertionError(msg) + + def assert_awaited_with(_mock_self, *args, **kwargs): + """ + Assert that the last await was with the specified arguments. + """ + self = _mock_self + if self.await_args is None: + expected = self._format_mock_call_signature(args, kwargs) + raise AssertionError(f'Expected await: {expected}\nNot awaited') + + def _error_message(): + msg = self._format_mock_failure_message(args, kwargs) + return msg + + expected = self._call_matcher((args, kwargs)) + actual = self._call_matcher(self.await_args) + if expected != actual: + cause = expected if isinstance(expected, Exception) else None + raise AssertionError(_error_message()) from cause + + def assert_awaited_once_with(_mock_self, *args, **kwargs): + """ + Assert that the mock was awaited exactly once and with the specified + arguments. + """ + self = _mock_self + if not self.await_count == 1: + msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once." + f" Awaited {self.await_count} times.") + raise AssertionError(msg) + return self.assert_awaited_with(*args, **kwargs) + + def assert_any_await(_mock_self, *args, **kwargs): + """ + Assert the mock has ever been awaited with the specified arguments. + """ + self = _mock_self + expected = self._call_matcher((args, kwargs)) + actual = [self._call_matcher(c) for c in self.await_args_list] + if expected not in actual: + cause = expected if isinstance(expected, Exception) else None + expected_string = self._format_mock_call_signature(args, kwargs) + raise AssertionError( + '%s await not found' % expected_string + ) from cause + + def assert_has_awaits(_mock_self, calls, any_order=False): + """ + Assert the mock has been awaited with the specified calls. + The :attr:`await_args_list` list is checked for the awaits. + + If `any_order` is False (the default) then the awaits must be + sequential. There can be extra calls before or after the + specified awaits. + + If `any_order` is True then the awaits can be in any order, but + they must all appear in :attr:`await_args_list`. + """ + self = _mock_self + expected = [self._call_matcher(c) for c in calls] + cause = expected if isinstance(expected, Exception) else None + all_awaits = _CallList(self._call_matcher(c) for c in self.await_args_list) + if not any_order: + if expected not in all_awaits: + raise AssertionError( + f'Awaits not found.\nExpected: {_CallList(calls)}\n', + f'Actual: {self.await_args_list}' + ) from cause + return + + all_awaits = list(all_awaits) + + not_found = [] + for kall in expected: + try: + all_awaits.remove(kall) + except ValueError: + not_found.append(kall) + if not_found: + raise AssertionError( + '%r not all found in await list' % (tuple(not_found),) + ) from cause + + def assert_not_awaited(_mock_self): + """ + Assert that the mock was never awaited. + """ + self = _mock_self + if self.await_count != 0: + msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once." + f" Awaited {self.await_count} times.") + raise AssertionError(msg) + + def reset_mock(self, *args, **kwargs): + """ + See :func:`.Mock.reset_mock()` + """ + super().reset_mock(*args, **kwargs) + self.await_count = 0 + self.await_args = None + self.await_args_list = _CallList() + + +class AsyncMock(AsyncMockMixin, AsyncMagicMixin, Mock): + """ + Enhance :class:`Mock` with features allowing to mock + an async function. + + The :class:`AsyncMock` object will behave so the object is + recognized as an async function, and the result of a call is an awaitable: + + >>> mock = AsyncMock() + >>> asyncio.iscoroutinefunction(mock) + True + >>> inspect.isawaitable(mock()) + True + + + The result of ``mock()`` is an async function which will have the outcome + of ``side_effect`` or ``return_value``: + + - if ``side_effect`` is a function, the async function will return the + result of that function, + - if ``side_effect`` is an exception, the async function will raise the + exception, + - if ``side_effect`` is an iterable, the async function will return the + next value of the iterable, however, if the sequence of result is + exhausted, ``StopIteration`` is raised immediately, + - if ``side_effect`` is not defined, the async function will return the + value defined by ``return_value``, hence, by default, the async function + returns a new :class:`AsyncMock` object. + + If the outcome of ``side_effect`` or ``return_value`` is an async function, + the mock async function obtained when the mock object is called will be this + async function itself (and not an async function returning an async + function). + + The test author can also specify a wrapped object with ``wraps``. In this + case, the :class:`Mock` object behavior is the same as with an + :class:`.Mock` object: the wrapped object may have methods + defined as async function functions. + + Based on Martin Richard's asyntest project. + """ + class _ANY(object): "A helper object that compares equal to everything." @@ -2144,7 +2431,6 @@ def call_list(self): call = _Call(from_kall=False) - def create_autospec(spec, spec_set=False, instance=False, _parent=None, _name=None, **kwargs): """Create a mock object using another object as a spec. Attributes on the @@ -2170,7 +2456,10 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, spec = type(spec) is_type = isinstance(spec, type) - + if getattr(spec, '__code__', None): + is_async_func = asyncio.iscoroutinefunction(spec) + else: + is_async_func = False _kwargs = {'spec': spec} if spec_set: _kwargs = {'spec_set': spec} @@ -2187,6 +2476,11 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, # descriptors don't have a spec # because we don't know what type they return _kwargs = {} + elif is_async_func: + if instance: + raise RuntimeError("Instance can not be True when create_autospec " + "is mocking an async function") + Klass = AsyncMock elif not _callable(spec): Klass = NonCallableMagicMock elif is_type and instance and not _instance_callable(spec): @@ -2203,9 +2497,26 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, name=_name, **_kwargs) if isinstance(spec, FunctionTypes): + wrapped_mock = mock # should only happen at the top level because we don't # recurse for functions mock = _set_signature(mock, spec) + if is_async_func: + mock._is_coroutine = asyncio.coroutines._is_coroutine + mock.await_count = 0 + mock.await_args = None + mock.await_args_list = _CallList() + + for a in ('assert_awaited', + 'assert_awaited_once', + 'assert_awaited_with', + 'assert_awaited_once_with', + 'assert_any_await', + 'assert_has_awaits', + 'assert_not_awaited'): + def f(*args, **kwargs): + return getattr(wrapped_mock, a)(*args, **kwargs) + setattr(mock, a, f) else: _check_signature(spec, mock, is_type, instance) @@ -2249,9 +2560,13 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, skipfirst = _must_skip(spec, entry, is_type) kwargs['_eat_self'] = skipfirst - new = MagicMock(parent=parent, name=entry, _new_name=entry, - _new_parent=parent, - **kwargs) + if asyncio.iscoroutinefunction(original): + child_klass = AsyncMock + else: + child_klass = MagicMock + new = child_klass(parent=parent, name=entry, _new_name=entry, + _new_parent=parent, + **kwargs) mock._mock_children[entry] = new _check_signature(original, new, skipfirst=skipfirst) @@ -2437,3 +2752,60 @@ def seal(mock): continue if m._mock_new_parent is mock: seal(m) + + +async def _raise(exception): + raise exception + + +class _AsyncIterator: + """ + Wraps an iterator in an asynchronous iterator. + """ + def __init__(self, iterator): + self.iterator = iterator + code_mock = NonCallableMock(spec_set=CodeType) + code_mock.co_flags = inspect.CO_ITERABLE_COROUTINE + self.__dict__['__code__'] = code_mock + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iterator) + except StopIteration: + pass + raise StopAsyncIteration + + +class _AwaitEvent: + def __init__(self, mock): + self._mock = mock + self._condition = None + + async def _notify(self): + condition = self._get_condition() + try: + await condition.acquire() + condition.notify_all() + finally: + condition.release() + + def _get_condition(self): + """ + Creation of condition is delayed, to minimize the chance of using the + wrong loop. + A user may create a mock with _AwaitEvent before selecting the + execution loop. Requiring a user to delay creation is error-prone and + inflexible. Instead, condition is created when user actually starts to + use the mock. + """ + # No synchronization is needed: + # - asyncio is thread unsafe + # - there are no awaits here, method will be executed without + # switching asyncio context. + if self._condition is None: + self._condition = asyncio.Condition() + + return self._condition diff --git a/Lib/unittest/test/testmock/testasync.py b/Lib/unittest/test/testmock/testasync.py new file mode 100644 index 00000000000000..a9aa1434b963f1 --- /dev/null +++ b/Lib/unittest/test/testmock/testasync.py @@ -0,0 +1,549 @@ +import asyncio +import inspect +import unittest + +from unittest.mock import call, AsyncMock, patch, MagicMock, create_autospec + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class AsyncClass: + def __init__(self): + pass + async def async_method(self): + pass + def normal_method(self): + pass + +async def async_func(): + pass + +def normal_func(): + pass + +class NormalClass(object): + def a(self): + pass + + +async_foo_name = f'{__name__}.AsyncClass' +normal_foo_name = f'{__name__}.NormalClass' + + +class AsyncPatchDecoratorTest(unittest.TestCase): + def test_is_coroutine_function_patch(self): + @patch.object(AsyncClass, 'async_method') + def test_async(mock_method): + self.assertTrue(asyncio.iscoroutinefunction(mock_method)) + test_async() + + def test_is_async_patch(self): + @patch.object(AsyncClass, 'async_method') + def test_async(mock_method): + m = mock_method() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + @patch(f'{async_foo_name}.async_method') + def test_no_parent_attribute(mock_method): + m = mock_method() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + test_async() + test_no_parent_attribute() + + def test_is_AsyncMock_patch(self): + @patch.object(AsyncClass, 'async_method') + def test_async(mock_method): + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + +class AsyncPatchCMTest(unittest.TestCase): + def test_is_async_function_cm(self): + def test_async(): + with patch.object(AsyncClass, 'async_method') as mock_method: + self.assertTrue(asyncio.iscoroutinefunction(mock_method)) + + test_async() + + def test_is_async_cm(self): + def test_async(): + with patch.object(AsyncClass, 'async_method') as mock_method: + m = mock_method() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + test_async() + + def test_is_AsyncMock_cm(self): + def test_async(): + with patch.object(AsyncClass, 'async_method') as mock_method: + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + +class AsyncMockTest(unittest.TestCase): + def test_iscoroutinefunction_default(self): + mock = AsyncMock() + self.assertTrue(asyncio.iscoroutinefunction(mock)) + + def test_iscoroutinefunction_function(self): + async def foo(): pass + mock = AsyncMock(foo) + self.assertTrue(asyncio.iscoroutinefunction(mock)) + self.assertTrue(inspect.iscoroutinefunction(mock)) + + def test_isawaitable(self): + mock = AsyncMock() + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + self.assertIn('assert_awaited', dir(mock)) + + def test_iscoroutinefunction_normal_function(self): + def foo(): pass + mock = AsyncMock(foo) + self.assertTrue(asyncio.iscoroutinefunction(mock)) + self.assertTrue(inspect.iscoroutinefunction(mock)) + + def test_future_isfuture(self): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + fut = asyncio.Future() + loop.stop() + loop.close() + mock = AsyncMock(fut) + self.assertIsInstance(mock, asyncio.Future) + + +class AsyncAutospecTest(unittest.TestCase): + def test_is_AsyncMock_patch(self): + @patch(async_foo_name, autospec=True) + def test_async(mock_method): + self.assertIsInstance(mock_method.async_method, AsyncMock) + self.assertIsInstance(mock_method, MagicMock) + + @patch(async_foo_name, autospec=True) + def test_normal_method(mock_method): + self.assertIsInstance(mock_method.normal_method, MagicMock) + + test_async() + test_normal_method() + + def test_create_autospec_instance(self): + with self.assertRaises(RuntimeError): + create_autospec(async_func, instance=True) + + def test_create_autospec(self): + spec = create_autospec(async_func) + self.assertTrue(asyncio.iscoroutinefunction(spec)) + + +class AsyncSpecTest(unittest.TestCase): + def test_spec_as_async_positional_magicmock(self): + mock = MagicMock(async_func) + self.assertIsInstance(mock, MagicMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + def test_spec_as_async_kw_magicmock(self): + mock = MagicMock(spec=async_func) + self.assertIsInstance(mock, MagicMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + def test_spec_as_async_kw_AsyncMock(self): + mock = AsyncMock(spec=async_func) + self.assertIsInstance(mock, AsyncMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + def test_spec_as_async_positional_AsyncMock(self): + mock = AsyncMock(async_func) + self.assertIsInstance(mock, AsyncMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + def test_spec_as_normal_kw_AsyncMock(self): + mock = AsyncMock(spec=normal_func) + self.assertIsInstance(mock, AsyncMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + def test_spec_as_normal_positional_AsyncMock(self): + mock = AsyncMock(normal_func) + self.assertIsInstance(mock, AsyncMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + def test_spec_async_mock(self): + @patch.object(AsyncClass, 'async_method', spec=True) + def test_async(mock_method): + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + def test_spec_parent_not_async_attribute_is(self): + @patch(async_foo_name, spec=True) + def test_async(mock_method): + self.assertIsInstance(mock_method, MagicMock) + self.assertIsInstance(mock_method.async_method, AsyncMock) + + test_async() + + def test_target_async_spec_not(self): + @patch.object(AsyncClass, 'async_method', spec=NormalClass.a) + def test_async_attribute(mock_method): + self.assertIsInstance(mock_method, MagicMock) + self.assertFalse(inspect.iscoroutine(mock_method)) + self.assertFalse(inspect.isawaitable(mock_method)) + + test_async_attribute() + + def test_target_not_async_spec_is(self): + @patch.object(NormalClass, 'a', spec=async_func) + def test_attribute_not_async_spec_is(mock_async_func): + self.assertIsInstance(mock_async_func, AsyncMock) + test_attribute_not_async_spec_is() + + def test_spec_async_attributes(self): + @patch(normal_foo_name, spec=AsyncClass) + def test_async_attributes_coroutines(MockNormalClass): + self.assertIsInstance(MockNormalClass.async_method, AsyncMock) + self.assertIsInstance(MockNormalClass, MagicMock) + + test_async_attributes_coroutines() + + +class AsyncSpecSetTest(unittest.TestCase): + def test_is_AsyncMock_patch(self): + @patch.object(AsyncClass, 'async_method', spec_set=True) + def test_async(async_method): + self.assertIsInstance(async_method, AsyncMock) + + def test_is_async_AsyncMock(self): + mock = AsyncMock(spec_set=AsyncClass.async_method) + self.assertTrue(asyncio.iscoroutinefunction(mock)) + self.assertIsInstance(mock, AsyncMock) + + def test_is_child_AsyncMock(self): + mock = MagicMock(spec_set=AsyncClass) + self.assertTrue(asyncio.iscoroutinefunction(mock.async_method)) + self.assertFalse(asyncio.iscoroutinefunction(mock.normal_method)) + self.assertIsInstance(mock.async_method, AsyncMock) + self.assertIsInstance(mock.normal_method, MagicMock) + self.assertIsInstance(mock, MagicMock) + + +class AsyncArguments(unittest.TestCase): + def test_add_return_value(self): + async def addition(self, var): + return var + 1 + + mock = AsyncMock(addition, return_value=10) + output = asyncio.run(mock(5)) + + self.assertEqual(output, 10) + + def test_add_side_effect_exception(self): + async def addition(var): + return var + 1 + mock = AsyncMock(addition, side_effect=Exception('err')) + with self.assertRaises(Exception): + asyncio.run(mock(5)) + + def test_add_side_effect_function(self): + async def addition(var): + return var + 1 + mock = AsyncMock(side_effect=addition) + result = asyncio.run(mock(5)) + self.assertEqual(result, 6) + + def test_add_side_effect_iterable(self): + vals = [1, 2, 3] + mock = AsyncMock(side_effect=vals) + for item in vals: + self.assertEqual(item, asyncio.run(mock())) + + with self.assertRaises(RuntimeError) as e: + asyncio.run(mock()) + self.assertEqual( + e.exception, + RuntimeError('coroutine raised StopIteration') + ) + + +class AsyncContextManagerTest(unittest.TestCase): + class WithAsyncContextManager: + def __init__(self): + self.entered = False + self.exited = False + + async def __aenter__(self, *args, **kwargs): + self.entered = True + return self + + async def __aexit__(self, *args, **kwargs): + self.exited = True + + def test_magic_methods_are_async_mocks(self): + mock = MagicMock(self.WithAsyncContextManager()) + self.assertIsInstance(mock.__aenter__, AsyncMock) + self.assertIsInstance(mock.__aexit__, AsyncMock) + + def test_mock_supports_async_context_manager(self): + called = False + instance = self.WithAsyncContextManager() + mock_instance = MagicMock(instance) + + async def use_context_manager(): + nonlocal called + async with mock_instance as result: + called = True + return result + + result = asyncio.run(use_context_manager()) + self.assertFalse(instance.entered) + self.assertFalse(instance.exited) + self.assertTrue(called) + self.assertTrue(mock_instance.entered) + self.assertTrue(mock_instance.exited) + self.assertTrue(mock_instance.__aenter__.called) + self.assertTrue(mock_instance.__aexit__.called) + self.assertIsNot(mock_instance, result) + self.assertIsInstance(result, AsyncMock) + + def test_mock_customize_async_context_manager(self): + instance = self.WithAsyncContextManager() + mock_instance = MagicMock(instance) + + expected_result = object() + mock_instance.__aenter__.return_value = expected_result + + async def use_context_manager(): + async with mock_instance as result: + return result + + self.assertIs(asyncio.run(use_context_manager()), expected_result) + + def test_mock_customize_async_context_manager_with_coroutine(self): + enter_called = False + exit_called = False + + async def enter_coroutine(*args): + nonlocal enter_called + enter_called = True + + async def exit_coroutine(*args): + nonlocal exit_called + exit_called = True + + instance = self.WithAsyncContextManager() + mock_instance = MagicMock(instance) + + mock_instance.__aenter__ = enter_coroutine + mock_instance.__aexit__ = exit_coroutine + + async def use_context_manager(): + async with mock_instance: + pass + + asyncio.run(use_context_manager()) + self.assertTrue(enter_called) + self.assertTrue(exit_called) + + def test_context_manager_raise_exception_by_default(self): + async def raise_in(context_manager): + async with context_manager: + raise TypeError() + + instance = self.WithAsyncContextManager() + mock_instance = MagicMock(instance) + with self.assertRaises(TypeError): + asyncio.run(raise_in(mock_instance)) + + +class AsyncIteratorTest(unittest.TestCase): + class WithAsyncIterator(object): + def __init__(self): + self.items = ["foo", "NormalFoo", "baz"] + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return self.items.pop() + except IndexError: + pass + + raise StopAsyncIteration + + def test_mock_aiter_and_anext(self): + instance = self.WithAsyncIterator() + mock_instance = MagicMock(instance) + + self.assertEqual(asyncio.iscoroutine(instance.__aiter__), + asyncio.iscoroutine(mock_instance.__aiter__)) + self.assertEqual(asyncio.iscoroutine(instance.__anext__), + asyncio.iscoroutine(mock_instance.__anext__)) + + iterator = instance.__aiter__() + if asyncio.iscoroutine(iterator): + iterator = asyncio.run(iterator) + + mock_iterator = mock_instance.__aiter__() + if asyncio.iscoroutine(mock_iterator): + mock_iterator = asyncio.run(mock_iterator) + + self.assertEqual(asyncio.iscoroutine(iterator.__aiter__), + asyncio.iscoroutine(mock_iterator.__aiter__)) + self.assertEqual(asyncio.iscoroutine(iterator.__anext__), + asyncio.iscoroutine(mock_iterator.__anext__)) + + def test_mock_async_for(self): + async def iterate(iterator): + accumulator = [] + async for item in iterator: + accumulator.append(item) + + return accumulator + + expected = ["FOO", "BAR", "BAZ"] + with self.subTest("iterate through default value"): + mock_instance = MagicMock(self.WithAsyncIterator()) + self.assertEqual([], asyncio.run(iterate(mock_instance))) + + with self.subTest("iterate through set return_value"): + mock_instance = MagicMock(self.WithAsyncIterator()) + mock_instance.__aiter__.return_value = expected[:] + self.assertEqual(expected, asyncio.run(iterate(mock_instance))) + + with self.subTest("iterate through set return_value iterator"): + mock_instance = MagicMock(self.WithAsyncIterator()) + mock_instance.__aiter__.return_value = iter(expected[:]) + self.assertEqual(expected, asyncio.run(iterate(mock_instance))) + + +class AsyncMockAssert(unittest.TestCase): + def setUp(self): + self.mock = AsyncMock() + + async def _runnable_test(self, *args): + if not args: + await self.mock() + else: + await self.mock(*args) + + def test_assert_awaited(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited() + + asyncio.run(self._runnable_test()) + self.mock.assert_awaited() + + def test_assert_awaited_once(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once() + + asyncio.run(self._runnable_test()) + self.mock.assert_awaited_once() + + asyncio.run(self._runnable_test()) + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once() + + def test_assert_awaited_with(self): + asyncio.run(self._runnable_test()) + with self.assertRaises(AssertionError): + self.mock.assert_awaited_with('foo') + + asyncio.run(self._runnable_test('foo')) + self.mock.assert_awaited_with('foo') + + asyncio.run(self._runnable_test('SomethingElse')) + with self.assertRaises(AssertionError): + self.mock.assert_awaited_with('foo') + + def test_assert_awaited_once_with(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once_with('foo') + + asyncio.run(self._runnable_test('foo')) + self.mock.assert_awaited_once_with('foo') + + asyncio.run(self._runnable_test('foo')) + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once_with('foo') + + def test_assert_any_wait(self): + with self.assertRaises(AssertionError): + self.mock.assert_any_await('NormalFoo') + + asyncio.run(self._runnable_test('foo')) + with self.assertRaises(AssertionError): + self.mock.assert_any_await('NormalFoo') + + asyncio.run(self._runnable_test('NormalFoo')) + self.mock.assert_any_await('NormalFoo') + + asyncio.run(self._runnable_test('SomethingElse')) + self.mock.assert_any_await('NormalFoo') + + def test_assert_has_awaits_no_order(self): + calls = [call('NormalFoo'), call('baz')] + + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls) + + asyncio.run(self._runnable_test('foo')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls) + + asyncio.run(self._runnable_test('NormalFoo')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls) + + asyncio.run(self._runnable_test('baz')) + self.mock.assert_has_awaits(calls) + + asyncio.run(self._runnable_test('SomethingElse')) + self.mock.assert_has_awaits(calls) + + def test_assert_has_awaits_ordered(self): + calls = [call('NormalFoo'), call('baz')] + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls, any_order=True) + + asyncio.run(self._runnable_test('baz')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls, any_order=True) + + asyncio.run(self._runnable_test('foo')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls, any_order=True) + + asyncio.run(self._runnable_test('NormalFoo')) + self.mock.assert_has_awaits(calls, any_order=True) + + asyncio.run(self._runnable_test('qux')) + self.mock.assert_has_awaits(calls, any_order=True) + + def test_assert_not_awaited(self): + self.mock.assert_not_awaited() + + asyncio.run(self._runnable_test()) + with self.assertRaises(AssertionError): + self.mock.assert_not_awaited() diff --git a/Lib/unittest/test/testmock/testmock.py b/Lib/unittest/test/testmock/testmock.py index 5f917dd20f1dde..88c00b4e85d4f7 100644 --- a/Lib/unittest/test/testmock/testmock.py +++ b/Lib/unittest/test/testmock/testmock.py @@ -9,7 +9,7 @@ from unittest.mock import ( call, DEFAULT, patch, sentinel, MagicMock, Mock, NonCallableMock, - NonCallableMagicMock, _Call, _CallList, + NonCallableMagicMock, AsyncMock, _Call, _CallList, create_autospec ) @@ -1617,7 +1617,8 @@ def test_mock_add_spec_magic_methods(self): def test_adding_child_mock(self): - for Klass in NonCallableMock, Mock, MagicMock, NonCallableMagicMock: + for Klass in (NonCallableMock, Mock, MagicMock, NonCallableMagicMock, + AsyncMock): mock = Klass() mock.foo = Mock() diff --git a/Misc/NEWS.d/next/Library/2018-09-13-20-33-24.bpo-26467.cahAk3.rst b/Misc/NEWS.d/next/Library/2018-09-13-20-33-24.bpo-26467.cahAk3.rst new file mode 100644 index 00000000000000..4cf3f2ae7ef7e6 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-09-13-20-33-24.bpo-26467.cahAk3.rst @@ -0,0 +1,2 @@ +Added AsyncMock to support using unittest to mock asyncio coroutines. +Patch by Lisa Roach. pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy