diff --git a/Lib/test/test__xxsubinterpreters.py b/Lib/test/test__xxsubinterpreters.py index 118f2e4895fe12..f66cc95169260d 100644 --- a/Lib/test/test__xxsubinterpreters.py +++ b/Lib/test/test__xxsubinterpreters.py @@ -1379,12 +1379,104 @@ def test_close_multiple_times(self): with self.assertRaises(interpreters.ChannelClosedError): interpreters.channel_close(cid) - def test_close_with_unused_items(self): + def test_close_empty(self): + tests = [ + (False, False), + (True, False), + (False, True), + (True, True), + ] + for send, recv in tests: + with self.subTest((send, recv)): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_recv(cid) + interpreters.channel_close(cid, send=send, recv=recv) + + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'eggs') + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(cid) + + def test_close_defaults_with_unused_items(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'ham') + + with self.assertRaises(interpreters.ChannelNotEmptyError): + interpreters.channel_close(cid) + interpreters.channel_recv(cid) + interpreters.channel_send(cid, b'eggs') + + def test_close_recv_with_unused_items_unforced(self): cid = interpreters.channel_create() interpreters.channel_send(cid, b'spam') interpreters.channel_send(cid, b'ham') - interpreters.channel_close(cid) + + with self.assertRaises(interpreters.ChannelNotEmptyError): + interpreters.channel_close(cid, recv=True) + interpreters.channel_recv(cid) + interpreters.channel_send(cid, b'eggs') + interpreters.channel_recv(cid) + interpreters.channel_recv(cid) + interpreters.channel_close(cid, recv=True) + + def test_close_send_with_unused_items_unforced(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'ham') + interpreters.channel_close(cid, send=True) + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'eggs') + interpreters.channel_recv(cid) + interpreters.channel_recv(cid) + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(cid) + + def test_close_both_with_unused_items_unforced(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'ham') + + with self.assertRaises(interpreters.ChannelNotEmptyError): + interpreters.channel_close(cid, recv=True, send=True) + interpreters.channel_recv(cid) + interpreters.channel_send(cid, b'eggs') + interpreters.channel_recv(cid) + interpreters.channel_recv(cid) + interpreters.channel_close(cid, recv=True) + + def test_close_recv_with_unused_items_forced(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'ham') + interpreters.channel_close(cid, recv=True, force=True) + + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'eggs') + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(cid) + + def test_close_send_with_unused_items_forced(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'ham') + interpreters.channel_close(cid, send=True, force=True) + + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'eggs') + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_recv(cid) + + def test_close_both_with_unused_items_forced(self): + cid = interpreters.channel_create() + interpreters.channel_send(cid, b'spam') + interpreters.channel_send(cid, b'ham') + interpreters.channel_close(cid, send=True, recv=True, force=True) + + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send(cid, b'eggs') with self.assertRaises(interpreters.ChannelClosedError): interpreters.channel_recv(cid) @@ -1403,7 +1495,7 @@ def test_close_by_unassociated_interp(self): interp = interpreters.create() interpreters.run_string(interp, dedent(f""" import _xxsubinterpreters as _interpreters - _interpreters.channel_close({cid}) + _interpreters.channel_close({cid}, force=True) """)) with self.assertRaises(interpreters.ChannelClosedError): interpreters.channel_recv(cid) @@ -1416,7 +1508,7 @@ def test_close_used_multiple_times_by_single_user(self): interpreters.channel_send(cid, b'spam') interpreters.channel_send(cid, b'spam') interpreters.channel_recv(cid) - interpreters.channel_close(cid) + interpreters.channel_close(cid, force=True) with self.assertRaises(interpreters.ChannelClosedError): interpreters.channel_send(cid, b'eggs') diff --git a/Modules/_xxsubinterpretersmodule.c b/Modules/_xxsubinterpretersmodule.c index 5184f6593db15e..72387d8da56b51 100644 --- a/Modules/_xxsubinterpretersmodule.c +++ b/Modules/_xxsubinterpretersmodule.c @@ -306,10 +306,15 @@ _sharedexception_apply(_sharedexception *exc, PyObject *wrapperclass) /* channel-specific code ****************************************************/ +#define CHANNEL_SEND 1 +#define CHANNEL_BOTH 0 +#define CHANNEL_RECV -1 + static PyObject *ChannelError; static PyObject *ChannelNotFoundError; static PyObject *ChannelClosedError; static PyObject *ChannelEmptyError; +static PyObject *ChannelNotEmptyError; static int channel_exceptions_init(PyObject *ns) @@ -356,6 +361,16 @@ channel_exceptions_init(PyObject *ns) return -1; } + // An operation tried to close a non-empty channel. + ChannelNotEmptyError = PyErr_NewException( + "_xxsubinterpreters.ChannelNotEmptyError", ChannelError, NULL); + if (ChannelNotEmptyError == NULL) { + return -1; + } + if (PyDict_SetItemString(ns, "ChannelNotEmptyError", ChannelNotEmptyError) != 0) { + return -1; + } + return 0; } @@ -696,8 +711,11 @@ _channelends_close_interpreter(_channelends *ends, int64_t interp, int which) } static void -_channelends_close_all(_channelends *ends) +_channelends_close_all(_channelends *ends, int which, int force) { + // XXX Handle the ends. + // XXX Handle force is True. + // Ensure all the "send"-associated interpreters are closed. _channelend *end; for (end = ends->send; end != NULL; end = end->next) { @@ -713,12 +731,16 @@ _channelends_close_all(_channelends *ends) /* channels */ struct _channel; +struct _channel_closing; +static void _channel_clear_closing(struct _channel *); +static void _channel_finish_closing(struct _channel *); typedef struct _channel { PyThread_type_lock mutex; _channelqueue *queue; _channelends *ends; int open; + struct _channel_closing *closing; } _PyChannelState; static _PyChannelState * @@ -747,12 +769,14 @@ _channel_new(void) return NULL; } chan->open = 1; + chan->closing = NULL; return chan; } static void _channel_free(_PyChannelState *chan) { + _channel_clear_closing(chan); PyThread_acquire_lock(chan->mutex, WAIT_LOCK); _channelqueue_free(chan->queue); _channelends_free(chan->ends); @@ -802,13 +826,20 @@ _channel_next(_PyChannelState *chan, int64_t interp) } data = _channelqueue_get(chan->queue); + if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) { + chan->open = 0; + } + done: PyThread_release_lock(chan->mutex); + if (chan->queue->count == 0) { + _channel_finish_closing(chan); + } return data; } static int -_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which) +_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end) { PyThread_acquire_lock(chan->mutex, WAIT_LOCK); @@ -818,7 +849,7 @@ _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which) goto done; } - if (_channelends_close_interpreter(chan->ends, interp, which) != 0) { + if (_channelends_close_interpreter(chan->ends, interp, end) != 0) { goto done; } chan->open = _channelends_is_open(chan->ends); @@ -830,7 +861,7 @@ _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which) } static int -_channel_close_all(_PyChannelState *chan) +_channel_close_all(_PyChannelState *chan, int end, int force) { int res = -1; PyThread_acquire_lock(chan->mutex, WAIT_LOCK); @@ -840,11 +871,17 @@ _channel_close_all(_PyChannelState *chan) goto done; } + if (!force && chan->queue->count > 0) { + PyErr_SetString(ChannelNotEmptyError, + "may not be closed if not empty (try force=True)"); + goto done; + } + chan->open = 0; // We *could* also just leave these in place, since we've marked // the channel as closed already. - _channelends_close_all(chan->ends); + _channelends_close_all(chan->ends, end, force); res = 0; done: @@ -889,6 +926,9 @@ _channelref_new(int64_t id, _PyChannelState *chan) static void _channelref_free(_channelref *ref) { + if (ref->chan != NULL) { + _channel_clear_closing(ref->chan); + } //_channelref_clear(ref); PyMem_Free(ref); } @@ -1009,8 +1049,12 @@ _channels_add(_channels *channels, _PyChannelState *chan) return cid; } +/* forward */ +static int _channel_set_closing(struct _channelref *, PyThread_type_lock); + static int -_channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan) +_channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan, + int end, int force) { int res = -1; PyThread_acquire_lock(channels->mutex, WAIT_LOCK); @@ -1028,14 +1072,35 @@ _channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan) PyErr_Format(ChannelClosedError, "channel %d closed", cid); goto done; } + else if (!force && end == CHANNEL_SEND && ref->chan->closing != NULL) { + PyErr_Format(ChannelClosedError, "channel %d closed", cid); + goto done; + } else { - if (_channel_close_all(ref->chan) != 0) { + if (_channel_close_all(ref->chan, end, force) != 0) { + if (end == CHANNEL_SEND && + PyErr_ExceptionMatches(ChannelNotEmptyError)) { + if (ref->chan->closing != NULL) { + PyErr_Format(ChannelClosedError, "channel %d closed", cid); + goto done; + } + // Mark the channel as closing and return. The channel + // will be cleaned up in _channel_next(). + PyErr_Clear(); + if (_channel_set_closing(ref, channels->mutex) != 0) { + goto done; + } + if (pchan != NULL) { + *pchan = ref->chan; + } + res = 0; + } goto done; } if (pchan != NULL) { *pchan = ref->chan; } - else { + else { _channel_free(ref->chan); } ref->chan = NULL; @@ -1161,6 +1226,60 @@ _channels_list_all(_channels *channels, int64_t *count) return cids; } +/* support for closing non-empty channels */ + +struct _channel_closing { + struct _channelref *ref; +}; + +static int +_channel_set_closing(struct _channelref *ref, PyThread_type_lock mutex) { + struct _channel *chan = ref->chan; + if (chan == NULL) { + // already closed + return 0; + } + int res = -1; + PyThread_acquire_lock(chan->mutex, WAIT_LOCK); + if (chan->closing != NULL) { + PyErr_SetString(ChannelClosedError, "channel closed"); + goto done; + } + chan->closing = PyMem_NEW(struct _channel_closing, 1); + if (chan->closing == NULL) { + goto done; + } + chan->closing->ref = ref; + + res = 0; +done: + PyThread_release_lock(chan->mutex); + return res; +} + +static void +_channel_clear_closing(struct _channel *chan) { + PyThread_acquire_lock(chan->mutex, WAIT_LOCK); + if (chan->closing != NULL) { + PyMem_Free(chan->closing); + chan->closing = NULL; + } + PyThread_release_lock(chan->mutex); +} + +static void +_channel_finish_closing(struct _channel *chan) { + struct _channel_closing *closing = chan->closing; + if (closing == NULL) { + return; + } + _channelref *ref = closing->ref; + _channel_clear_closing(chan); + // Do the things that would have been done in _channels_close(). + ref->chan = NULL; + _channel_free(chan); +}; + /* "high"-level channel-related functions */ static int64_t @@ -1207,6 +1326,12 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj) } // Past this point we are responsible for releasing the mutex. + if (chan->closing != NULL) { + PyErr_Format(ChannelClosedError, "channel %d closed", id); + PyThread_release_lock(mutex); + return -1; + } + // Convert the object to cross-interpreter data. _PyCrossInterpreterData *data = PyMem_NEW(_PyCrossInterpreterData, 1); if (data == NULL) { @@ -1290,16 +1415,13 @@ _channel_drop(_channels *channels, int64_t id, int send, int recv) } static int -_channel_close(_channels *channels, int64_t id) +_channel_close(_channels *channels, int64_t id, int end, int force) { - return _channels_close(channels, id, NULL); + return _channels_close(channels, id, NULL, end, force); } /* ChannelID class */ -#define CHANNEL_SEND 1 -#define CHANNEL_RECV -1 - static PyTypeObject ChannelIDtype; typedef struct channelid { @@ -2555,15 +2677,8 @@ channel_close(PyObject *self, PyObject *args, PyObject *kwds) if (cid < 0) { return NULL; } - if (send == 0 && recv == 0) { - send = 1; - recv = 1; - } - - // XXX Handle the ends. - // XXX Handle force is True. - if (_channel_close(&_globals.channels, cid) != 0) { + if (_channel_close(&_globals.channels, cid, send-recv, force) != 0) { return NULL; } Py_RETURN_NONE; pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy