diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst index 4776853b5a56d8..4f0f8c06fee787 100644 --- a/Doc/library/asyncio-eventloop.rst +++ b/Doc/library/asyncio-eventloop.rst @@ -330,7 +330,7 @@ Creating Futures and Tasks .. versionadded:: 3.5.2 -.. method:: loop.create_task(coro, *, name=None) +.. method:: loop.create_task(coro, *, name=None, context=None) Schedule the execution of a :ref:`coroutine`. Return a :class:`Task` object. @@ -342,9 +342,16 @@ Creating Futures and Tasks If the *name* argument is provided and not ``None``, it is set as the name of the task using :meth:`Task.set_name`. + An optional keyword-only *context* argument allows specifying a + custom :class:`contextvars.Context` for the *coro* to run in. + The current context copy is created when no *context* is provided. + .. versionchanged:: 3.8 Added the *name* parameter. + .. versionchanged:: 3.11 + Added the *context* parameter. + .. method:: loop.set_task_factory(factory) Set a task factory that will be used by @@ -352,7 +359,7 @@ Creating Futures and Tasks If *factory* is ``None`` the default task factory will be set. Otherwise, *factory* must be a *callable* with the signature matching - ``(loop, coro)``, where *loop* is a reference to the active + ``(loop, coro, context=None)``, where *loop* is a reference to the active event loop, and *coro* is a coroutine object. The callable must return a :class:`asyncio.Future`-compatible object. diff --git a/Doc/library/asyncio-task.rst b/Doc/library/asyncio-task.rst index b30b2894277a2a..faf5910124f9b7 100644 --- a/Doc/library/asyncio-task.rst +++ b/Doc/library/asyncio-task.rst @@ -244,7 +244,7 @@ Running an asyncio Program Creating Tasks ============== -.. function:: create_task(coro, *, name=None) +.. function:: create_task(coro, *, name=None, context=None) Wrap the *coro* :ref:`coroutine ` into a :class:`Task` and schedule its execution. Return the Task object. @@ -252,6 +252,10 @@ Creating Tasks If *name* is not ``None``, it is set as the name of the task using :meth:`Task.set_name`. + An optional keyword-only *context* argument allows specifying a + custom :class:`contextvars.Context` for the *coro* to run in. + The current context copy is created when no *context* is provided. + The task is executed in the loop returned by :func:`get_running_loop`, :exc:`RuntimeError` is raised if there is no running loop in current thread. @@ -281,6 +285,9 @@ Creating Tasks .. versionchanged:: 3.8 Added the *name* parameter. + .. versionchanged:: 3.11 + Added the *context* parameter. + Sleeping ======== diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 51c4e664d74e9d..5eea1658df8f6f 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -426,18 +426,23 @@ def create_future(self): """Create a Future object attached to the loop.""" return futures.Future(loop=self) - def create_task(self, coro, *, name=None): + def create_task(self, coro, *, name=None, context=None): """Schedule a coroutine object. Return a task object. """ self._check_closed() if self._task_factory is None: - task = tasks.Task(coro, loop=self, name=name) + task = tasks.Task(coro, loop=self, name=name, context=context) if task._source_traceback: del task._source_traceback[-1] else: - task = self._task_factory(self, coro) + if context is None: + # Use legacy API if context is not needed + task = self._task_factory(self, coro) + else: + task = self._task_factory(self, coro, context=context) + tasks._set_task_name(task, name) return task diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index e682a192a887f2..0d26ea545baa5d 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -274,7 +274,7 @@ def create_future(self): # Method scheduling a coroutine object: create a task. - def create_task(self, coro, *, name=None): + def create_task(self, coro, *, name=None, context=None): raise NotImplementedError # Methods for interacting with threads. diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index c3ce94a4dd0a95..6af21f3a15d93a 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -138,12 +138,15 @@ async def __aexit__(self, et, exc, tb): me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors) raise me from None - def create_task(self, coro, *, name=None): + def create_task(self, coro, *, name=None, context=None): if not self._entered: raise RuntimeError(f"TaskGroup {self!r} has not been entered") if self._exiting and self._unfinished_tasks == 0: raise RuntimeError(f"TaskGroup {self!r} is finished") - task = self._loop.create_task(coro) + if context is None: + task = self._loop.create_task(coro) + else: + task = self._loop.create_task(coro, context=context) tasks._set_task_name(task, name) task.add_done_callback(self._on_task_done) self._unfinished_tasks += 1 diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index e604298e5efc01..b4f1eed91a9321 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -93,7 +93,7 @@ class Task(futures._PyFuture): # Inherit Python Task implementation # status is still pending _log_destroy_pending = True - def __init__(self, coro, *, loop=None, name=None): + def __init__(self, coro, *, loop=None, name=None, context=None): super().__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] @@ -112,7 +112,10 @@ def __init__(self, coro, *, loop=None, name=None): self._must_cancel = False self._fut_waiter = None self._coro = coro - self._context = contextvars.copy_context() + if context is None: + self._context = contextvars.copy_context() + else: + self._context = context self._loop.call_soon(self.__step, context=self._context) _register_task(self) @@ -360,13 +363,18 @@ def __wakeup(self, future): Task = _CTask = _asyncio.Task -def create_task(coro, *, name=None): +def create_task(coro, *, name=None, context=None): """Schedule the execution of a coroutine object in a spawn task. Return a Task object. """ loop = events.get_running_loop() - task = loop.create_task(coro) + if context is None: + # Use legacy API if context is not needed + task = loop.create_task(coro) + else: + task = loop.create_task(coro, context=context) + _set_task_name(task, name) return task diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index df51528e107939..dea5d6de524204 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -2,6 +2,7 @@ import asyncio +import contextvars from asyncio import taskgroups import unittest @@ -708,6 +709,23 @@ async def coro(): t = g.create_task(coro(), name="yolo") self.assertEqual(t.get_name(), "yolo") + async def test_taskgroup_task_context(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async with taskgroups.TaskGroup() as g: + ctx = contextvars.copy_context() + self.assertIsNone(ctx.get(cvar)) + t1 = g.create_task(coro(1), context=ctx) + await t1 + self.assertEqual(1, ctx.get(cvar)) + t2 = g.create_task(coro(2), context=ctx) + await t2 + self.assertEqual(2, ctx.get(cvar)) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py index 95fabf728818bb..b6ef62725166dc 100644 --- a/Lib/test/test_asyncio/test_tasks.py +++ b/Lib/test/test_asyncio/test_tasks.py @@ -95,8 +95,8 @@ class BaseTaskTests: Task = None Future = None - def new_task(self, loop, coro, name='TestTask'): - return self.__class__.Task(coro, loop=loop, name=name) + def new_task(self, loop, coro, name='TestTask', context=None): + return self.__class__.Task(coro, loop=loop, name=name, context=context) def new_future(self, loop): return self.__class__.Future(loop=loop) @@ -2527,6 +2527,90 @@ async def main(): self.assertEqual(cvar.get(), -1) + def test_context_4(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async def main(): + ret = [] + ctx = contextvars.copy_context() + ret.append(ctx.get(cvar)) + t1 = self.new_task(loop, coro(1), context=ctx) + await t1 + ret.append(ctx.get(cvar)) + t2 = self.new_task(loop, coro(2), context=ctx) + await t2 + ret.append(ctx.get(cvar)) + return ret + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main()) + ret = loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual([None, 1, 2], ret) + + def test_context_5(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async def main(): + ret = [] + ctx = contextvars.copy_context() + ret.append(ctx.get(cvar)) + t1 = asyncio.create_task(coro(1), context=ctx) + await t1 + ret.append(ctx.get(cvar)) + t2 = asyncio.create_task(coro(2), context=ctx) + await t2 + ret.append(ctx.get(cvar)) + return ret + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main()) + ret = loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual([None, 1, 2], ret) + + def test_context_6(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async def main(): + ret = [] + ctx = contextvars.copy_context() + ret.append(ctx.get(cvar)) + t1 = loop.create_task(coro(1), context=ctx) + await t1 + ret.append(ctx.get(cvar)) + t2 = loop.create_task(coro(2), context=ctx) + await t2 + ret.append(ctx.get(cvar)) + return ret + + loop = asyncio.new_event_loop() + try: + task = loop.create_task(main()) + ret = loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual([None, 1, 2], ret) + def test_get_coro(self): loop = asyncio.new_event_loop() coro = coroutine_function() diff --git a/Lib/unittest/async_case.py b/Lib/unittest/async_case.py index 3c57bb5cda2c03..25adc3deff63d1 100644 --- a/Lib/unittest/async_case.py +++ b/Lib/unittest/async_case.py @@ -1,4 +1,5 @@ import asyncio +import contextvars import inspect import warnings @@ -34,7 +35,7 @@ class IsolatedAsyncioTestCase(TestCase): def __init__(self, methodName='runTest'): super().__init__(methodName) self._asyncioTestLoop = None - self._asyncioCallsQueue = None + self._asyncioTestContext = contextvars.copy_context() async def asyncSetUp(self): pass @@ -58,7 +59,7 @@ def addAsyncCleanup(self, func, /, *args, **kwargs): self.addCleanup(*(func, *args), **kwargs) def _callSetUp(self): - self.setUp() + self._asyncioTestContext.run(self.setUp) self._callAsync(self.asyncSetUp) def _callTestMethod(self, method): @@ -68,47 +69,30 @@ def _callTestMethod(self, method): def _callTearDown(self): self._callAsync(self.asyncTearDown) - self.tearDown() + self._asyncioTestContext.run(self.tearDown) def _callCleanup(self, function, *args, **kwargs): self._callMaybeAsync(function, *args, **kwargs) def _callAsync(self, func, /, *args, **kwargs): assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized' - ret = func(*args, **kwargs) - assert inspect.isawaitable(ret), f'{func!r} returned non-awaitable' - fut = self._asyncioTestLoop.create_future() - self._asyncioCallsQueue.put_nowait((fut, ret)) - return self._asyncioTestLoop.run_until_complete(fut) + assert inspect.iscoroutinefunction(func), f'{func!r} is not an async function' + task = self._asyncioTestLoop.create_task( + func(*args, **kwargs), + context=self._asyncioTestContext, + ) + return self._asyncioTestLoop.run_until_complete(task) def _callMaybeAsync(self, func, /, *args, **kwargs): assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized' - ret = func(*args, **kwargs) - if inspect.isawaitable(ret): - fut = self._asyncioTestLoop.create_future() - self._asyncioCallsQueue.put_nowait((fut, ret)) - return self._asyncioTestLoop.run_until_complete(fut) + if inspect.iscoroutinefunction(func): + task = self._asyncioTestLoop.create_task( + func(*args, **kwargs), + context=self._asyncioTestContext, + ) + return self._asyncioTestLoop.run_until_complete(task) else: - return ret - - async def _asyncioLoopRunner(self, fut): - self._asyncioCallsQueue = queue = asyncio.Queue() - fut.set_result(None) - while True: - query = await queue.get() - queue.task_done() - if query is None: - return - fut, awaitable = query - try: - ret = await awaitable - if not fut.cancelled(): - fut.set_result(ret) - except (SystemExit, KeyboardInterrupt): - raise - except (BaseException, asyncio.CancelledError) as ex: - if not fut.cancelled(): - fut.set_exception(ex) + return self._asyncioTestContext.run(func, *args, **kwargs) def _setupAsyncioLoop(self): assert self._asyncioTestLoop is None, 'asyncio test loop already initialized' @@ -116,16 +100,11 @@ def _setupAsyncioLoop(self): asyncio.set_event_loop(loop) loop.set_debug(True) self._asyncioTestLoop = loop - fut = loop.create_future() - self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut)) - loop.run_until_complete(fut) def _tearDownAsyncioLoop(self): assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized' loop = self._asyncioTestLoop self._asyncioTestLoop = None - self._asyncioCallsQueue.put_nowait(None) - loop.run_until_complete(self._asyncioCallsQueue.join()) try: # cancel all tasks diff --git a/Lib/unittest/test/test_async_case.py b/Lib/unittest/test/test_async_case.py index 3717486b26563e..7dc8a6bffa019e 100644 --- a/Lib/unittest/test/test_async_case.py +++ b/Lib/unittest/test/test_async_case.py @@ -1,4 +1,5 @@ import asyncio +import contextvars import unittest from test import support @@ -11,6 +12,9 @@ def tearDownModule(): asyncio.set_event_loop_policy(None) +VAR = contextvars.ContextVar('VAR', default=()) + + class TestAsyncCase(unittest.TestCase): maxDiff = None @@ -24,22 +28,26 @@ class Test(unittest.IsolatedAsyncioTestCase): def setUp(self): self.assertEqual(events, []) events.append('setUp') + VAR.set(VAR.get() + ('setUp',)) async def asyncSetUp(self): self.assertEqual(events, ['setUp']) events.append('asyncSetUp') + VAR.set(VAR.get() + ('asyncSetUp',)) self.addAsyncCleanup(self.on_cleanup1) async def test_func(self): self.assertEqual(events, ['setUp', 'asyncSetUp']) events.append('test') + VAR.set(VAR.get() + ('test',)) self.addAsyncCleanup(self.on_cleanup2) async def asyncTearDown(self): self.assertEqual(events, ['setUp', 'asyncSetUp', 'test']) + VAR.set(VAR.get() + ('asyncTearDown',)) events.append('asyncTearDown') def tearDown(self): @@ -48,6 +56,7 @@ def tearDown(self): 'test', 'asyncTearDown']) events.append('tearDown') + VAR.set(VAR.get() + ('tearDown',)) async def on_cleanup1(self): self.assertEqual(events, ['setUp', @@ -57,6 +66,9 @@ async def on_cleanup1(self): 'tearDown', 'cleanup2']) events.append('cleanup1') + VAR.set(VAR.get() + ('cleanup1',)) + nonlocal cvar + cvar = VAR.get() async def on_cleanup2(self): self.assertEqual(events, ['setUp', @@ -65,8 +77,10 @@ async def on_cleanup2(self): 'asyncTearDown', 'tearDown']) events.append('cleanup2') + VAR.set(VAR.get() + ('cleanup2',)) events = [] + cvar = () test = Test("test_func") result = test.run() self.assertEqual(result.errors, []) @@ -74,13 +88,17 @@ async def on_cleanup2(self): expected = ['setUp', 'asyncSetUp', 'test', 'asyncTearDown', 'tearDown', 'cleanup2', 'cleanup1'] self.assertEqual(events, expected) + self.assertEqual(cvar, tuple(expected)) events = [] + cvar = () test = Test("test_func") test.debug() self.assertEqual(events, expected) + self.assertEqual(cvar, tuple(expected)) test.doCleanups() self.assertEqual(events, expected) + self.assertEqual(cvar, tuple(expected)) def test_exception_in_setup(self): class Test(unittest.IsolatedAsyncioTestCase): diff --git a/Misc/NEWS.d/next/Library/2022-03-12-12-34-13.bpo-46994.d7hPdz.rst b/Misc/NEWS.d/next/Library/2022-03-12-12-34-13.bpo-46994.d7hPdz.rst new file mode 100644 index 00000000000000..765936f1efb594 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-03-12-12-34-13.bpo-46994.d7hPdz.rst @@ -0,0 +1,2 @@ +Accept explicit contextvars.Context in :func:`asyncio.create_task` and +:meth:`asyncio.loop.create_task`. diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c index 2a6c0b335ccfb0..4b12744e625e19 100644 --- a/Modules/_asynciomodule.c +++ b/Modules/_asynciomodule.c @@ -2003,14 +2003,16 @@ _asyncio.Task.__init__ * loop: object = None name: object = None + context: object = None A coroutine wrapped in a Future. [clinic start generated code]*/ static int _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop, - PyObject *name) -/*[clinic end generated code: output=88b12b83d570df50 input=352a3137fe60091d]*/ + PyObject *name, PyObject *context) +/*[clinic end generated code: output=49ac96fe33d0e5c7 input=924522490c8ce825]*/ + { if (future_init((FutureObj*)self, loop)) { return -1; @@ -2028,9 +2030,13 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop, return -1; } - Py_XSETREF(self->task_context, PyContext_CopyCurrent()); - if (self->task_context == NULL) { - return -1; + if (context == Py_None) { + Py_XSETREF(self->task_context, PyContext_CopyCurrent()); + if (self->task_context == NULL) { + return -1; + } + } else { + self->task_context = Py_NewRef(context); } Py_CLEAR(self->task_fut_waiter); diff --git a/Modules/clinic/_asynciomodule.c.h b/Modules/clinic/_asynciomodule.c.h index 2b84ef0a477c71..4a90dfa67c22b2 100644 --- a/Modules/clinic/_asynciomodule.c.h +++ b/Modules/clinic/_asynciomodule.c.h @@ -310,28 +310,29 @@ _asyncio_Future__repr_info(FutureObj *self, PyObject *Py_UNUSED(ignored)) } PyDoc_STRVAR(_asyncio_Task___init____doc__, -"Task(coro, *, loop=None, name=None)\n" +"Task(coro, *, loop=None, name=None, context=None)\n" "--\n" "\n" "A coroutine wrapped in a Future."); static int _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop, - PyObject *name); + PyObject *name, PyObject *context); static int _asyncio_Task___init__(PyObject *self, PyObject *args, PyObject *kwargs) { int return_value = -1; - static const char * const _keywords[] = {"coro", "loop", "name", NULL}; + static const char * const _keywords[] = {"coro", "loop", "name", "context", NULL}; static _PyArg_Parser _parser = {NULL, _keywords, "Task", 0}; - PyObject *argsbuf[3]; + PyObject *argsbuf[4]; PyObject * const *fastargs; Py_ssize_t nargs = PyTuple_GET_SIZE(args); Py_ssize_t noptargs = nargs + (kwargs ? PyDict_GET_SIZE(kwargs) : 0) - 1; PyObject *coro; PyObject *loop = Py_None; PyObject *name = Py_None; + PyObject *context = Py_None; fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 1, 1, 0, argsbuf); if (!fastargs) { @@ -347,9 +348,15 @@ _asyncio_Task___init__(PyObject *self, PyObject *args, PyObject *kwargs) goto skip_optional_kwonly; } } - name = fastargs[2]; + if (fastargs[2]) { + name = fastargs[2]; + if (!--noptargs) { + goto skip_optional_kwonly; + } + } + context = fastargs[3]; skip_optional_kwonly: - return_value = _asyncio_Task___init___impl((TaskObj *)self, coro, loop, name); + return_value = _asyncio_Task___init___impl((TaskObj *)self, coro, loop, name, context); exit: return return_value; @@ -917,4 +924,4 @@ _asyncio__leave_task(PyObject *module, PyObject *const *args, Py_ssize_t nargs, exit: return return_value; } -/*[clinic end generated code: output=344927e9b6016ad7 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=540ed3caf5a4d57d input=a9049054013a1b77]*/ pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy