Skip to content

Commit 964c29d

Browse files
miss-islingtonyoneympage
authored
[3.14] gh-116738: Make _heapq module thread-safe (GH-135036) (gh-135309)
Use critical sections to make heapq methods that update the heap thread-safe when the GIL is disabled. (cherry picked from commit a58026a) Co-authored-by: Alper <alperyoney@fb.com> Co-authored-by: mpage <mpage@meta.com>
1 parent 15f7bd4 commit 964c29d

File tree

4 files changed

+303
-15
lines changed

4 files changed

+303
-15
lines changed
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
import unittest
2+
3+
import heapq
4+
5+
from enum import Enum
6+
from threading import Thread, Barrier
7+
from random import shuffle, randint
8+
9+
from test.support import threading_helper
10+
from test import test_heapq
11+
12+
13+
NTHREADS = 10
14+
OBJECT_COUNT = 5_000
15+
16+
17+
class Heap(Enum):
18+
MIN = 1
19+
MAX = 2
20+
21+
22+
@threading_helper.requires_working_threading()
23+
class TestHeapq(unittest.TestCase):
24+
def setUp(self):
25+
self.test_heapq = test_heapq.TestHeapPython()
26+
27+
def test_racing_heapify(self):
28+
heap = list(range(OBJECT_COUNT))
29+
shuffle(heap)
30+
31+
self.run_concurrently(
32+
worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS
33+
)
34+
self.test_heapq.check_invariant(heap)
35+
36+
def test_racing_heappush(self):
37+
heap = []
38+
39+
def heappush_func(heap):
40+
for item in reversed(range(OBJECT_COUNT)):
41+
heapq.heappush(heap, item)
42+
43+
self.run_concurrently(
44+
worker_func=heappush_func, args=(heap,), nthreads=NTHREADS
45+
)
46+
self.test_heapq.check_invariant(heap)
47+
48+
def test_racing_heappop(self):
49+
heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
50+
51+
# Each thread pops (OBJECT_COUNT / NTHREADS) items
52+
self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
53+
per_thread_pop_count = OBJECT_COUNT // NTHREADS
54+
55+
def heappop_func(heap, pop_count):
56+
local_list = []
57+
for _ in range(pop_count):
58+
item = heapq.heappop(heap)
59+
local_list.append(item)
60+
61+
# Each local list should be sorted
62+
self.assertTrue(self.is_sorted_ascending(local_list))
63+
64+
self.run_concurrently(
65+
worker_func=heappop_func,
66+
args=(heap, per_thread_pop_count),
67+
nthreads=NTHREADS,
68+
)
69+
self.assertEqual(len(heap), 0)
70+
71+
def test_racing_heappushpop(self):
72+
heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
73+
pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
74+
75+
def heappushpop_func(heap, pushpop_items):
76+
for item in pushpop_items:
77+
popped_item = heapq.heappushpop(heap, item)
78+
self.assertTrue(popped_item <= item)
79+
80+
self.run_concurrently(
81+
worker_func=heappushpop_func,
82+
args=(heap, pushpop_items),
83+
nthreads=NTHREADS,
84+
)
85+
self.assertEqual(len(heap), OBJECT_COUNT)
86+
self.test_heapq.check_invariant(heap)
87+
88+
def test_racing_heapreplace(self):
89+
heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
90+
replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
91+
92+
def heapreplace_func(heap, replace_items):
93+
for item in replace_items:
94+
heapq.heapreplace(heap, item)
95+
96+
self.run_concurrently(
97+
worker_func=heapreplace_func,
98+
args=(heap, replace_items),
99+
nthreads=NTHREADS,
100+
)
101+
self.assertEqual(len(heap), OBJECT_COUNT)
102+
self.test_heapq.check_invariant(heap)
103+
104+
def test_racing_heapify_max(self):
105+
max_heap = list(range(OBJECT_COUNT))
106+
shuffle(max_heap)
107+
108+
self.run_concurrently(
109+
worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS
110+
)
111+
self.test_heapq.check_max_invariant(max_heap)
112+
113+
def test_racing_heappush_max(self):
114+
max_heap = []
115+
116+
def heappush_max_func(max_heap):
117+
for item in range(OBJECT_COUNT):
118+
heapq.heappush_max(max_heap, item)
119+
120+
self.run_concurrently(
121+
worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS
122+
)
123+
self.test_heapq.check_max_invariant(max_heap)
124+
125+
def test_racing_heappop_max(self):
126+
max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
127+
128+
# Each thread pops (OBJECT_COUNT / NTHREADS) items
129+
self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
130+
per_thread_pop_count = OBJECT_COUNT // NTHREADS
131+
132+
def heappop_max_func(max_heap, pop_count):
133+
local_list = []
134+
for _ in range(pop_count):
135+
item = heapq.heappop_max(max_heap)
136+
local_list.append(item)
137+
138+
# Each local list should be sorted
139+
self.assertTrue(self.is_sorted_descending(local_list))
140+
141+
self.run_concurrently(
142+
worker_func=heappop_max_func,
143+
args=(max_heap, per_thread_pop_count),
144+
nthreads=NTHREADS,
145+
)
146+
self.assertEqual(len(max_heap), 0)
147+
148+
def test_racing_heappushpop_max(self):
149+
max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
150+
pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
151+
152+
def heappushpop_max_func(max_heap, pushpop_items):
153+
for item in pushpop_items:
154+
popped_item = heapq.heappushpop_max(max_heap, item)
155+
self.assertTrue(popped_item >= item)
156+
157+
self.run_concurrently(
158+
worker_func=heappushpop_max_func,
159+
args=(max_heap, pushpop_items),
160+
nthreads=NTHREADS,
161+
)
162+
self.assertEqual(len(max_heap), OBJECT_COUNT)
163+
self.test_heapq.check_max_invariant(max_heap)
164+
165+
def test_racing_heapreplace_max(self):
166+
max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
167+
replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
168+
169+
def heapreplace_max_func(max_heap, replace_items):
170+
for item in replace_items:
171+
heapq.heapreplace_max(max_heap, item)
172+
173+
self.run_concurrently(
174+
worker_func=heapreplace_max_func,
175+
args=(max_heap, replace_items),
176+
nthreads=NTHREADS,
177+
)
178+
self.assertEqual(len(max_heap), OBJECT_COUNT)
179+
self.test_heapq.check_max_invariant(max_heap)
180+
181+
@staticmethod
182+
def is_sorted_ascending(lst):
183+
"""
184+
Check if the list is sorted in ascending order (non-decreasing).
185+
"""
186+
return all(lst[i - 1] <= lst[i] for i in range(1, len(lst)))
187+
188+
@staticmethod
189+
def is_sorted_descending(lst):
190+
"""
191+
Check if the list is sorted in descending order (non-increasing).
192+
"""
193+
return all(lst[i - 1] >= lst[i] for i in range(1, len(lst)))
194+
195+
@staticmethod
196+
def create_heap(size, heap_kind):
197+
"""
198+
Create a min/max heap where elements are in the range (0, size - 1) and
199+
shuffled before heapify.
200+
"""
201+
heap = list(range(OBJECT_COUNT))
202+
shuffle(heap)
203+
if heap_kind == Heap.MIN:
204+
heapq.heapify(heap)
205+
else:
206+
heapq.heapify_max(heap)
207+
208+
return heap
209+
210+
@staticmethod
211+
def create_random_list(a, b, size):
212+
"""
213+
Create a list of random numbers between a and b (inclusive).
214+
"""
215+
return [randint(-a, b) for _ in range(size)]
216+
217+
def run_concurrently(self, worker_func, args, nthreads):
218+
"""
219+
Run the worker function concurrently in multiple threads.
220+
"""
221+
barrier = Barrier(nthreads)
222+
223+
def wrapper_func(*args):
224+
# Wait for all threads to reach this point before proceeding.
225+
barrier.wait()
226+
worker_func(*args)
227+
228+
with threading_helper.catch_threading_exception() as cm:
229+
workers = (
230+
Thread(target=wrapper_func, args=args) for _ in range(nthreads)
231+
)
232+
with threading_helper.start_threads(workers):
233+
pass
234+
235+
# Worker threads should not raise any exceptions
236+
self.assertIsNone(cm.exc_value)
237+
238+
239+
if __name__ == "__main__":
240+
unittest.main()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make methods in :mod:`heapq` thread-safe on the :term:`free threaded <free threading>` build.

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy