diff --git a/Lib/test/support/interpreters.py b/Lib/test/support/interpreters.py index eeff3abe0324e5..d2beba31e80283 100644 --- a/Lib/test/support/interpreters.py +++ b/Lib/test/support/interpreters.py @@ -7,7 +7,8 @@ # aliases: from _xxsubinterpreters import is_shareable from _xxinterpchannels import ( - ChannelError, ChannelNotFoundError, ChannelEmptyError, + ChannelError, ChannelNotFoundError, ChannelClosedError, + ChannelEmptyError, ChannelNotEmptyError, ) @@ -117,10 +118,16 @@ def list_all_channels(): class _ChannelEnd: """The base class for RecvChannel and SendChannel.""" - def __init__(self, id): - if not isinstance(id, (int, _channels.ChannelID)): - raise TypeError(f'id must be an int, got {id!r}') - self._id = id + _end = None + + def __init__(self, cid): + if self._end == 'send': + cid = _channels._channel_id(cid, send=True, force=True) + elif self._end == 'recv': + cid = _channels._channel_id(cid, recv=True, force=True) + else: + raise NotImplementedError(self._end) + self._id = cid def __repr__(self): return f'{type(self).__name__}(id={int(self._id)})' @@ -147,6 +154,8 @@ def id(self): class RecvChannel(_ChannelEnd): """The receiving end of a cross-interpreter channel.""" + _end = 'recv' + def recv(self, *, _sentinel=object(), _delay=10 / 1000): # 10 milliseconds """Return the next object from the channel. @@ -171,10 +180,15 @@ def recv_nowait(self, default=_NOT_SET): else: return _channels.recv(self._id, default) + def close(self): + _channels.close(self._id, recv=True) + class SendChannel(_ChannelEnd): """The sending end of a cross-interpreter channel.""" + _end = 'send' + def send(self, obj): """Send the object (i.e. its data) to the channel's receiving end. @@ -196,3 +210,9 @@ def send_nowait(self, obj): # None. This should be fixed when channel_send_wait() is added. # See bpo-32604 and gh-19829. return _channels.send(self._id, obj) + + def close(self): + _channels.close(self._id, send=True) + + +_channels._register_end_types(SendChannel, RecvChannel) diff --git a/Lib/test/test_interpreters.py b/Lib/test/test_interpreters.py index 9c0dac7d6c61fb..8aa6d5b83604db 100644 --- a/Lib/test/test_interpreters.py +++ b/Lib/test/test_interpreters.py @@ -574,6 +574,22 @@ def test_list_all(self): after = set(interpreters.list_all_channels()) self.assertEqual(after, created) + def test_shareable(self): + rch, sch = interpreters.create_channel() + + self.assertTrue( + interpreters.is_shareable(rch)) + self.assertTrue( + interpreters.is_shareable(sch)) + + sch.send_nowait(rch) + sch.send_nowait(sch) + rch2 = rch.recv() + sch2 = rch.recv() + + self.assertEqual(rch2, rch) + self.assertEqual(sch2, sch) + class TestRecvChannelAttrs(TestBase): diff --git a/Modules/_xxinterpchannelsmodule.c b/Modules/_xxinterpchannelsmodule.c index 6096f88421a73a..d5be76f1f0e38e 100644 --- a/Modules/_xxinterpchannelsmodule.c +++ b/Modules/_xxinterpchannelsmodule.c @@ -198,6 +198,9 @@ _release_xid_data(_PyCrossInterpreterData *data, int flags) /* module state *************************************************************/ typedef struct { + PyTypeObject *send_channel_type; + PyTypeObject *recv_channel_type; + /* heap types */ PyTypeObject *ChannelIDType; @@ -218,6 +221,21 @@ get_module_state(PyObject *mod) return state; } +static module_state * +_get_current_module_state(void) +{ + PyObject *mod = _get_current_module(); + if (mod == NULL) { + // XXX import it? + PyErr_SetString(PyExc_RuntimeError, + MODULE_NAME " module not imported yet"); + return NULL; + } + module_state *state = get_module_state(mod); + Py_DECREF(mod); + return state; +} + static int traverse_module_state(module_state *state, visitproc visit, void *arg) { @@ -237,6 +255,9 @@ traverse_module_state(module_state *state, visitproc visit, void *arg) static int clear_module_state(module_state *state) { + Py_CLEAR(state->send_channel_type); + Py_CLEAR(state->recv_channel_type); + /* heap types */ if (state->ChannelIDType != NULL) { (void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType); @@ -1529,17 +1550,20 @@ typedef struct channelid { struct channel_id_converter_data { PyObject *module; int64_t cid; + int end; }; static int channel_id_converter(PyObject *arg, void *ptr) { int64_t cid; + int end = 0; struct channel_id_converter_data *data = ptr; module_state *state = get_module_state(data->module); assert(state != NULL); if (PyObject_TypeCheck(arg, state->ChannelIDType)) { cid = ((channelid *)arg)->id; + end = ((channelid *)arg)->end; } else if (PyIndex_Check(arg)) { cid = PyLong_AsLongLong(arg); @@ -1559,6 +1583,7 @@ channel_id_converter(PyObject *arg, void *ptr) return 0; } data->cid = cid; + data->end = end; return 1; } @@ -1600,6 +1625,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls, { static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL}; int64_t cid; + int end; struct channel_id_converter_data cid_data = { .module = mod, }; @@ -1614,6 +1640,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls, return NULL; } cid = cid_data.cid; + end = cid_data.end; // Handle "send" and "recv". if (send == 0 && recv == 0) { @@ -1621,14 +1648,17 @@ _channelid_new(PyObject *mod, PyTypeObject *cls, "'send' and 'recv' cannot both be False"); return NULL; } - - int end = 0; - if (send == 1) { + else if (send == 1) { if (recv == 0 || recv == -1) { end = CHANNEL_SEND; } + else { + assert(recv == 1); + end = 0; + } } else if (recv == 1) { + assert(send == 0 || send == -1); end = CHANNEL_RECV; } @@ -1773,21 +1803,12 @@ channelid_richcompare(PyObject *self, PyObject *other, int op) return res; } +static PyTypeObject * _get_current_channel_end_type(int end); + static PyObject * _channel_from_cid(PyObject *cid, int end) { - PyObject *highlevel = PyImport_ImportModule("interpreters"); - if (highlevel == NULL) { - PyErr_Clear(); - highlevel = PyImport_ImportModule("test.support.interpreters"); - if (highlevel == NULL) { - return NULL; - } - } - const char *clsname = (end == CHANNEL_RECV) ? "RecvChannel" : - "SendChannel"; - PyObject *cls = PyObject_GetAttrString(highlevel, clsname); - Py_DECREF(highlevel); + PyObject *cls = (PyObject *)_get_current_channel_end_type(end); if (cls == NULL) { return NULL; } @@ -1943,6 +1964,103 @@ static PyType_Spec ChannelIDType_spec = { }; +/* SendChannel and RecvChannel classes */ + +// XXX Use a new __xid__ protocol instead? + +static PyTypeObject * +_get_current_channel_end_type(int end) +{ + module_state *state = _get_current_module_state(); + if (state == NULL) { + return NULL; + } + PyTypeObject *cls; + if (end == CHANNEL_SEND) { + cls = state->send_channel_type; + } + else { + assert(end == CHANNEL_RECV); + cls = state->recv_channel_type; + } + if (cls == NULL) { + PyObject *highlevel = PyImport_ImportModule("interpreters"); + if (highlevel == NULL) { + PyErr_Clear(); + highlevel = PyImport_ImportModule("test.support.interpreters"); + if (highlevel == NULL) { + return NULL; + } + } + if (end == CHANNEL_SEND) { + cls = state->send_channel_type; + } + else { + cls = state->recv_channel_type; + } + assert(cls != NULL); + } + return cls; +} + +static PyObject * +_channel_end_from_xid(_PyCrossInterpreterData *data) +{ + channelid *cid = (channelid *)_channelid_from_xid(data); + if (cid == NULL) { + return NULL; + } + PyTypeObject *cls = _get_current_channel_end_type(cid->end); + if (cls == NULL) { + return NULL; + } + PyObject *obj = PyObject_CallOneArg((PyObject *)cls, (PyObject *)cid); + Py_DECREF(cid); + return obj; +} + +static int +_channel_end_shared(PyThreadState *tstate, PyObject *obj, + _PyCrossInterpreterData *data) +{ + PyObject *cidobj = PyObject_GetAttrString(obj, "_id"); + if (cidobj == NULL) { + return -1; + } + if (_channelid_shared(tstate, cidobj, data) < 0) { + return -1; + } + data->new_object = _channel_end_from_xid; + return 0; +} + +static int +set_channel_end_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv) +{ + module_state *state = get_module_state(mod); + if (state == NULL) { + return -1; + } + + if (state->send_channel_type != NULL + || state->recv_channel_type != NULL) + { + PyErr_SetString(PyExc_TypeError, "already registered"); + return -1; + } + state->send_channel_type = (PyTypeObject *)Py_NewRef(send); + state->recv_channel_type = (PyTypeObject *)Py_NewRef(recv); + + if (_PyCrossInterpreterData_RegisterClass(send, _channel_end_shared)) { + return -1; + } + if (_PyCrossInterpreterData_RegisterClass(recv, _channel_end_shared)) { + return -1; + } + + return 0; +} + /* module level code ********************************************************/ /* globals is the process-global state for the module. It holds all @@ -2346,13 +2464,38 @@ channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds) return NULL; } PyTypeObject *cls = state->ChannelIDType; - PyObject *mod = get_module_from_owned_type(cls); - if (mod == NULL) { + assert(get_module_from_owned_type(cls) == self); + + return _channelid_new(self, cls, args, kwds); +} + +static PyObject * +channel__register_end_types(PyObject *self, PyObject *args, PyObject *kwds) +{ + static char *kwlist[] = {"send", "recv", NULL}; + PyObject *send; + PyObject *recv; + if (!PyArg_ParseTupleAndKeywords(args, kwds, + "OO:_register_end_types", kwlist, + &send, &recv)) { return NULL; } - PyObject *cid = _channelid_new(mod, cls, args, kwds); - Py_DECREF(mod); - return cid; + if (!PyType_Check(send)) { + PyErr_SetString(PyExc_TypeError, "expected a type for 'send'"); + return NULL; + } + if (!PyType_Check(recv)) { + PyErr_SetString(PyExc_TypeError, "expected a type for 'recv'"); + return NULL; + } + PyTypeObject *cls_send = (PyTypeObject *)send; + PyTypeObject *cls_recv = (PyTypeObject *)recv; + + if (set_channel_end_types(self, cls_send, cls_recv) < 0) { + return NULL; + } + + Py_RETURN_NONE; } static PyMethodDef module_functions[] = { @@ -2374,6 +2517,8 @@ static PyMethodDef module_functions[] = { METH_VARARGS | METH_KEYWORDS, channel_release_doc}, {"_channel_id", _PyCFunction_CAST(channel__channel_id), METH_VARARGS | METH_KEYWORDS, NULL}, + {"_register_end_types", _PyCFunction_CAST(channel__register_end_types), + METH_VARARGS | METH_KEYWORDS, NULL}, {NULL, NULL} /* sentinel */ };
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: