diff --git a/Lib/test/test_ctypes/test_pointers.py b/Lib/test/test_ctypes/test_pointers.py index a8d243a45de0f4..792a3687dba93b 100644 --- a/Lib/test/test_ctypes/test_pointers.py +++ b/Lib/test/test_ctypes/test_pointers.py @@ -5,7 +5,7 @@ import unittest from ctypes import (CDLL, CFUNCTYPE, Structure, POINTER, pointer, _Pointer, - byref, sizeof, + addressof, byref, sizeof, c_void_p, c_char_p, c_byte, c_ubyte, c_short, c_ushort, c_int, c_uint, c_long, c_ulong, c_longlong, c_ulonglong, @@ -472,6 +472,105 @@ class C(Structure): ptr.set_type(c_int) self.assertIs(ptr._type_, c_int) + def test_pointer_lifecycle_basic(self): + i = c_long(1010) + p = pointer(i) + self.assertEqual(p[0], 1010) + self.assertIsNone(p._b_base_) + self.assertEqual(addressof(i), addressof(p.contents)) + + def test_pointer_lifecycle_set_contents(self): + i = c_long(2020) + p = pointer(c_long(1010)) + p.contents = i + self.assertEqual(p[0], 2020) + self.assertIsNone(p._b_base_) + self.assertEqual(addressof(i), addressof(p.contents)) + + def test_pointer_lifecycle_set_pointer_contents(self): + i = c_long(3030) + p = pointer(c_long(1010)) + pointer(p).contents.contents = i + self.assertEqual(p.contents.value, 3030) + self.assertEqual(addressof(i), addressof(p.contents)) + + def test_pointer_lifecycle_array_set_contents(self): + arr_type = POINTER(c_long) * 3 + arr_obj = arr_type() + i = c_long(300300) + arr_obj[0] = pointer(c_long(100100)) + arr_obj[1] = pointer(c_long(200200)) + arr_obj[2] = pointer(i) + self.assertEqual(arr_obj[0].contents.value, 100100) + self.assertEqual(arr_obj[1].contents.value, 200200) + self.assertEqual(arr_obj[2].contents.value, 300300) + self.assertEqual(addressof(i), addressof(arr_obj[2].contents)) + + def test_pointer_lifecycle_array_set_pointer_contents(self): + arr_type = POINTER(c_long) * 3 + arr_obj = arr_type() + i = c_long(200003) + arr_obj[0].contents = c_long(100001) + arr_obj[1].contents = c_long(200002) + arr_obj[2].contents = i + self.assertEqual(arr_obj[0].contents.value, 100001) + self.assertEqual(arr_obj[1].contents.value, 200002) + self.assertEqual(arr_obj[2].contents.value, 200003) + self.assertEqual(addressof(i), addressof(arr_obj[2].contents)) + + def test_pointer_lifecycle_array_set_pointer_contents_pointer(self): + arr_type = POINTER(c_long) * 3 + arr_obj = arr_type() + i = c_long(200003) + pointer(arr_obj[0]).contents.contents = c_long(100001) + pointer(arr_obj[1]).contents.contents = c_long(200002) + pointer(arr_obj[2]).contents.contents = i + self.assertEqual(arr_obj[0].contents.value, 100001) + self.assertEqual(arr_obj[1].contents.value, 200002) + self.assertEqual(arr_obj[2].contents.value, 200003) + self.assertEqual(addressof(i), addressof(arr_obj[2].contents)) + + def test_pointer_lifecycle_struct_set_contents(self): + class S(Structure): + _fields_ = (("s", POINTER(c_long)),) + s = S(s=pointer(c_long(1111111))) + s.s.contents = c_long(2222222) + self.assertEqual(s.s.contents.value, 2222222) + + def test_pointer_lifecycle_struct_set_contents_pointer(self): + class S(Structure): + _fields_ = (("s", POINTER(c_long)),) + s = S(s=pointer(c_long(1111111))) + pointer(s.s).contents.contents = c_long(2222222) + self.assertEqual(s.s.contents.value, 2222222) + + def test_pointer_lifecycle_struct_set_pointer_contents(self): + class S(Structure): + _fields_ = (("s", POINTER(c_long)),) + s = S(s=pointer(c_long(1111111))) + s.s = pointer(c_long(3333333)) + self.assertEqual(s.s.contents.value, 3333333) + + def test_pointer_lifecycle_struct_with_extra_field(self): + class U(Structure): + _fields_ = ( + ("s", POINTER(c_long)), + ("u", c_long), + ) + u = U(s=pointer(c_long(1010101))) + u.s.contents = c_long(202020202) + self.assertEqual(u.s.contents.value, 202020202) + + def test_pointer_lifecycle_struct_with_extra_field_pointer(self): + class U(Structure): + _fields_ = ( + ("s", POINTER(c_uint)), + ("u", c_uint), + ) + u = U(s=pointer(c_uint(1010101))) + pointer(u.s).contents.contents = c_uint(202020202) + self.assertEqual(u.s.contents.value, 202020202) + if __name__ == '__main__': unittest.main() diff --git a/Modules/_ctypes/_ctypes.c b/Modules/_ctypes/_ctypes.c index 7e8a133caa72ac..ccf438999fd082 100644 --- a/Modules/_ctypes/_ctypes.c +++ b/Modules/_ctypes/_ctypes.c @@ -2958,7 +2958,7 @@ KeepRef_lock_held(CDataObject *target, Py_ssize_t index, PyObject *keep) CDataObject *ob; PyObject *key; -/* Optimization: no need to store None */ + /* Optimization: no need to store None */ if (keep == Py_None) { Py_DECREF(Py_None); return 0; @@ -5718,8 +5718,34 @@ Pointer_set_contents_lock_held(PyObject *op, PyObject *value, void *closure) pointer instance has b_length set to 2 instead of 1, and we set 'value' itself as the second item of the b_objects list, additionally. */ + + CDataObject * root = self->b_base; + /* perhaps, this is a bit excessive: if we have are in a chain of pointers + that starts with non-pointer (e.g. a union), can we consider the current + pointer to be "detached" from this chain? */ + while (root != NULL && root->b_base != NULL) { + root = root->b_base; + } + + /* If the b_base is NULL now or if we are a part of chain of pointers fully + modeled within ctypes, AND the value is a pointer, array, struct or union, + we just override the b_base. */ + if ((root == NULL || PyType_IsSubtype(Py_TYPE(root), st->PyCPointer_Type)) && + (PyType_IsSubtype(Py_TYPE(value), st->PyCPointer_Type) || + PyType_IsSubtype(Py_TYPE(value), st->PyCArray_Type) || + PyType_IsSubtype(Py_TYPE(value), st->Struct_Type) || + PyType_IsSubtype(Py_TYPE(value), st->Union_Type)) + ) { + Py_XSETREF(self->b_base, (CDataObject *) Py_NewRef(value)); + return 0; // no need to add `value` to `keep` objects - it's in b_base + } + + /* If we are a part of chain of pointers that is not fully modeled within + ctypes, (or modeled in a complex way, e.g., with arrays and structures), + then everything should be covered by keepref logic bellow */ + Py_INCREF(value); - if (-1 == KeepRef(self, 1, value)) + if (-1 == KeepRef_lock_held(self, 1, value)) return -1; keep = GetKeepedObjects(dst); @@ -5727,7 +5753,7 @@ Pointer_set_contents_lock_held(PyObject *op, PyObject *value, void *closure) return -1; Py_INCREF(keep); - return KeepRef(self, 0, keep); + return KeepRef_lock_held(self, 0, keep); } static int
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: