Skip to content

Commit 681abde

Browse files
mariocj89tirkarthi
andcommitted
mock: Add EventMock class
Add a new class that allows to wait for a call to happen by using `Event` objects. This mock class can be used to test and validate expectations of multithreading code. Co-authored-by: Karthikeyan Singaravelan <tir.karthi@gmail.com>
1 parent c7437e2 commit 681abde

File tree

4 files changed

+263
-0
lines changed

4 files changed

+263
-0
lines changed

Doc/library/unittest.mock.rst

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,10 @@ The Mock Class
204204
import asyncio
205205
import inspect
206206
import unittest
207+
import threading
207208
from unittest.mock import sentinel, DEFAULT, ANY
208209
from unittest.mock import patch, call, Mock, MagicMock, PropertyMock, AsyncMock
210+
from unittest.mock import ThreadingMock
209211
from unittest.mock import mock_open
210212

211213
:class:`Mock` is a flexible mock object intended to replace the use of stubs and
@@ -1097,6 +1099,39 @@ object::
10971099
[call('foo'), call('bar')]
10981100

10991101

1102+
.. class:: ThreadingMock(spec=None, side_effect=None, return_value=DEFAULT, wraps=None, name=None, spec_set=None, unsafe=False, **kwargs)
1103+
1104+
A version of :class:`MagicMock` for multithreading tests. The
1105+
:class:`ThreadingMock` object provides extra methods to wait for a call to
1106+
happen on a different thread, rather than assert on it immediately.
1107+
1108+
.. method:: wait_until_called(mock_timeout=None)
1109+
1110+
Waits until the the mock is called.
1111+
If ``mock_timeout`` is set, after that number of seconds waiting,
1112+
it raises an :exc:`AssertionError`, waits forever otherwise.
1113+
1114+
>>> mock = ThreadingMock()
1115+
>>> thread = threading.Thread(target=mock)
1116+
>>> thread.start()
1117+
>>> mock.wait_until_called(mock_timeout=1)
1118+
>>> thread.join()
1119+
1120+
.. method:: wait_until_any_call(*args, mock_timeout=None, **kwargs)
1121+
1122+
Waits until the the mock is called with the specified arguments.
1123+
If ``mock_timeout`` is set, after that number of seconds waiting,
1124+
it raises an :exc:`AssertionError`, waits forever otherwise.
1125+
1126+
>>> mock = ThreadingMock()
1127+
>>> thread = threading.Thread(target=mock, args=(1,2,), kwargs={"arg": "thing"})
1128+
>>> thread.start()
1129+
>>> mock.wait_until_any_call(1, 2, arg="thing")
1130+
>>> thread.join()
1131+
1132+
.. versionadded:: 3.10
1133+
1134+
11001135
Calling
11011136
~~~~~~~
11021137

Lib/unittest/mock.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
'call',
1515
'create_autospec',
1616
'AsyncMock',
17+
'ThreadingMock',
1718
'FILTER_DIR',
1819
'NonCallableMock',
1920
'NonCallableMagicMock',
@@ -31,6 +32,7 @@
3132
import sys
3233
import builtins
3334
from asyncio import iscoroutinefunction
35+
import threading
3436
from types import CodeType, ModuleType, MethodType
3537
from unittest.util import safe_repr
3638
from functools import wraps, partial
@@ -2851,6 +2853,59 @@ def __set__(self, obj, val):
28512853
self(val)
28522854

28532855

2856+
class ThreadingMock(MagicMock):
2857+
"""
2858+
A mock that can be used to wait until on calls happening
2859+
in a different thread.
2860+
"""
2861+
2862+
def __init__(self, *args, **kwargs):
2863+
_safe_super(ThreadingMock, self).__init__(*args, **kwargs)
2864+
self.__dict__["_event"] = threading.Event()
2865+
self.__dict__["_expected_calls"] = []
2866+
self.__dict__["_events_lock"] = threading.Lock()
2867+
2868+
def __get_event(self, expected_args, expected_kwargs):
2869+
with self._events_lock:
2870+
for args, kwargs, event in self._expected_calls:
2871+
if (args, kwargs) == (expected_args, expected_kwargs):
2872+
return event
2873+
new_event = threading.Event()
2874+
self._expected_calls.append((expected_args, expected_kwargs, new_event))
2875+
return new_event
2876+
2877+
2878+
def _mock_call(self, *args, **kwargs):
2879+
ret_value = _safe_super(ThreadingMock, self)._mock_call(*args, **kwargs)
2880+
2881+
call_event = self.__get_event(args, kwargs)
2882+
call_event.set()
2883+
2884+
self._event.set()
2885+
2886+
return ret_value
2887+
2888+
def wait_until_called(self, mock_timeout=None):
2889+
"""Wait until the mock object is called.
2890+
2891+
`mock_timeout` - time to wait for in seconds, waits forever otherwise.
2892+
"""
2893+
if not self._event.wait(timeout=mock_timeout):
2894+
msg = (f"{self._mock_name or 'mock'} was not called before"
2895+
f" timeout({mock_timeout}).")
2896+
raise AssertionError(msg)
2897+
2898+
def wait_until_any_call(self, *args, mock_timeout=None, **kwargs):
2899+
"""Wait until the mock object is called with given args.
2900+
2901+
`mock_timeout` - time to wait for in seconds, waits forever otherwise.
2902+
"""
2903+
event = self.__get_event(args, kwargs)
2904+
if not event.wait(timeout=mock_timeout):
2905+
expected_string = self._format_mock_call_signature(args, kwargs)
2906+
raise AssertionError(f'{expected_string} call not found')
2907+
2908+
28542909
def seal(mock):
28552910
"""Disable the automatic generation of child mocks.
28562911
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import time
2+
import unittest
3+
import concurrent.futures
4+
5+
from unittest.mock import patch, ThreadingMock, call
6+
7+
8+
class Something:
9+
10+
def method_1(self):
11+
pass
12+
13+
def method_2(self):
14+
pass
15+
16+
17+
class TestThreadingMock(unittest.TestCase):
18+
19+
def _call_after_delay(self, func, /, *args, **kwargs):
20+
time.sleep(kwargs.pop('delay'))
21+
func(*args, **kwargs)
22+
23+
24+
def setUp(self):
25+
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
26+
27+
def tearDown(self):
28+
self._executor.shutdown()
29+
30+
def run_async(self, func, /, *args, delay=0, **kwargs):
31+
self._executor.submit(self._call_after_delay, func, *args, **kwargs, delay=delay)
32+
33+
def _make_mock(self, *args, **kwargs):
34+
return ThreadingMock(*args, **kwargs)
35+
36+
def test_instance_check(self):
37+
waitable_mock = self._make_mock()
38+
39+
with patch(f'{__name__}.Something', waitable_mock):
40+
something = Something()
41+
42+
self.assertIsInstance(something.method_1, ThreadingMock)
43+
self.assertIsInstance(
44+
something.method_1().method_2(), ThreadingMock)
45+
46+
47+
def test_side_effect(self):
48+
waitable_mock = self._make_mock()
49+
50+
with patch(f'{__name__}.Something', waitable_mock):
51+
something = Something()
52+
something.method_1.side_effect = [1]
53+
54+
self.assertEqual(something.method_1(), 1)
55+
56+
57+
def test_spec(self):
58+
waitable_mock = self._make_mock(spec=Something)
59+
60+
with patch(f'{__name__}.Something', waitable_mock) as m:
61+
something = m()
62+
63+
self.assertIsInstance(something.method_1, ThreadingMock)
64+
self.assertIsInstance(
65+
something.method_1().method_2(), ThreadingMock)
66+
67+
with self.assertRaises(AttributeError):
68+
m.test
69+
70+
71+
def test_wait_until_called(self):
72+
waitable_mock = self._make_mock(spec=Something)
73+
74+
with patch(f'{__name__}.Something', waitable_mock):
75+
something = Something()
76+
self.run_async(something.method_1, delay=0.01)
77+
something.method_1.wait_until_called()
78+
something.method_1.assert_called_once()
79+
80+
81+
def test_wait_until_called_called_before(self):
82+
waitable_mock = self._make_mock(spec=Something)
83+
84+
with patch(f'{__name__}.Something', waitable_mock):
85+
something = Something()
86+
something.method_1()
87+
something.method_1.wait_until_called()
88+
something.method_1.assert_called_once()
89+
90+
91+
def test_wait_until_called_magic_method(self):
92+
waitable_mock = self._make_mock(spec=Something)
93+
94+
with patch(f'{__name__}.Something', waitable_mock):
95+
something = Something()
96+
self.run_async(something.method_1.__str__, delay=0.01)
97+
something.method_1.__str__.wait_until_called()
98+
something.method_1.__str__.assert_called_once()
99+
100+
101+
def test_wait_until_called_timeout(self):
102+
waitable_mock = self._make_mock(spec=Something)
103+
104+
with patch(f'{__name__}.Something', waitable_mock):
105+
something = Something()
106+
self.run_async(something.method_1, delay=0.2)
107+
with self.assertRaises(AssertionError):
108+
something.method_1.wait_until_called(mock_timeout=0.1)
109+
something.method_1.assert_not_called()
110+
111+
something.method_1.wait_until_called()
112+
something.method_1.assert_called_once()
113+
114+
115+
def test_wait_until_any_call_positional(self):
116+
waitable_mock = self._make_mock(spec=Something)
117+
118+
with patch(f'{__name__}.Something', waitable_mock):
119+
something = Something()
120+
self.run_async(something.method_1, 1, delay=0.1)
121+
self.run_async(something.method_1, 2, delay=0.2)
122+
self.run_async(something.method_1, 3, delay=0.3)
123+
self.assertNotIn(call(1), something.method_1.mock_calls)
124+
125+
something.method_1.wait_until_any_call(1)
126+
something.method_1.assert_called_once_with(1)
127+
self.assertNotIn(call(2), something.method_1.mock_calls)
128+
self.assertNotIn(call(3), something.method_1.mock_calls)
129+
130+
something.method_1.wait_until_any_call(3)
131+
self.assertIn(call(2), something.method_1.mock_calls)
132+
something.method_1.wait_until_any_call(2)
133+
134+
135+
def test_wait_until_any_call_keywords(self):
136+
waitable_mock = self._make_mock(spec=Something)
137+
138+
with patch(f'{__name__}.Something', waitable_mock):
139+
something = Something()
140+
self.run_async(something.method_1, a=1, delay=0.1)
141+
self.run_async(something.method_1, b=2, delay=0.2)
142+
self.run_async(something.method_1, c=3, delay=0.3)
143+
self.assertNotIn(call(a=1), something.method_1.mock_calls)
144+
145+
something.method_1.wait_until_any_call(a=1)
146+
something.method_1.assert_called_once_with(a=1)
147+
self.assertNotIn(call(b=2), something.method_1.mock_calls)
148+
self.assertNotIn(call(c=3), something.method_1.mock_calls)
149+
150+
something.method_1.wait_until_any_call(c=3)
151+
self.assertIn(call(b=2), something.method_1.mock_calls)
152+
something.method_1.wait_until_any_call(b=2)
153+
154+
def test_wait_until_any_call_no_argument(self):
155+
waitable_mock = self._make_mock(spec=Something)
156+
157+
with patch(f'{__name__}.Something', waitable_mock):
158+
something = Something()
159+
something.method_1(1)
160+
161+
something.method_1.assert_called_once_with(1)
162+
with self.assertRaises(AssertionError):
163+
something.method_1.wait_until_any_call(mock_timeout=0.1)
164+
165+
something.method_1()
166+
something.method_1.wait_until_any_call()
167+
168+
169+
if __name__ == "__main__":
170+
unittest.main()
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Add `ThreadingMock` to :mod:`unittest.mock` that can be used to create
2+
Mock objects that can wait until they are called. Patch by Karthikeyan
3+
Singaravelan and Mario Corchero.

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy