
Commit db87e2b

Replace launching explicit threads with a ThreadPoolExecutor
Differential Revision: D72825598
Pull Request resolved: pytorch#1469
1 parent ac98992 commit db87e2b
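
The commit's core move is to hand `ParallelMapper`'s thread-mode workers to a `concurrent.futures.ThreadPoolExecutor` instead of creating, starting, and joining `threading.Thread` objects by hand. A minimal before/after sketch of that pattern, independent of torchdata (the `work` function and the worker count are illustrative, not from the commit):

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def work(worker_id: int) -> None:
    # Stand-in for a real worker loop such as torchdata's _apply_udf.
    print(f"worker {worker_id} ran")


num_workers = 2  # illustrative

# Before: explicit threads that have to be tracked, started, and joined one by one.
threads = [threading.Thread(target=work, args=(i,), daemon=True) for i in range(num_workers)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# After: the executor owns the threads; a single shutdown(wait=True) replaces the joins.
pool = ThreadPoolExecutor(max_workers=num_workers)
for i in range(num_workers):
    pool.submit(work, i)
pool.shutdown(wait=True)
```

The executor also leaves behind a single handle (`self.pool` in the diff below) that the cleanup path can shut down unconditionally, which is what the new tests check.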

File tree

2 files changed: +122 −24 lines changed


test/nodes/test_map.py

Lines changed: 64 additions & 3 deletions
@@ -5,9 +5,9 @@
 # LICENSE file in the root directory of this source tree.

 import itertools
-
 import unittest
 from typing import List, Optional
+from unittest import mock

 from parameterized import parameterized
 from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA, TestCase
@@ -131,7 +131,11 @@ def test_out_of_order_process_prebatch(self):
         )
     )
     def test_save_load_state_thread(
-        self, midpoint: int, in_order: bool, snapshot_frequency: int, prebatch: Optional[int]
+        self,
+        midpoint: int,
+        in_order: bool,
+        snapshot_frequency: int,
+        prebatch: Optional[int],
     ):
         method = "thread"
         batch_size = 6
@@ -159,7 +163,11 @@ def test_save_load_state_thread(
         )
     )
     def test_save_load_state_process(
-        self, midpoint: int, in_order: bool, snapshot_frequency: int, prebatch: Optional[int]
+        self,
+        midpoint: int,
+        in_order: bool,
+        snapshot_frequency: int,
+        prebatch: Optional[int],
     ):
         method = "process"
         batch_size = 6
@@ -179,3 +187,56 @@ def test_save_load_state_process(
         )
         node = Prefetcher(node, prefetch_factor=2)
         run_test_save_load_state(self, node, midpoint)
+
+    def test_thread_pool_executor_shutdown_on_del(self):
+        """Test that the ThreadPoolExecutor is properly shut down when the iterator is deleted."""
+        # Create a ParallelMapper with method="thread"
+        src = MockSource(num_samples=10)
+        node = ParallelMapper(
+            src,
+            RandomSleepUdf(),
+            num_workers=2,
+            method="thread",
+        )
+
+        # Reset the node to create the iterator
+        node.reset()
+
+        # We need to consume some items to ensure the ThreadPoolExecutor is created
+        # and the worker threads are started
+        for _ in range(5):
+            next(node)
+
+        # Use mock.patch to intercept the ThreadPoolExecutor.shutdown method
+        with mock.patch("concurrent.futures.ThreadPoolExecutor.shutdown") as mock_shutdown:
+            # Delete the node, which should trigger the shutdown of the ThreadPoolExecutor
+            del node
+
+            # Verify that shutdown was called
+            mock_shutdown.assert_called()
+
+    def test_thread_pool_executor_shutdown_on_exception(self):
+        """Test that the ThreadPoolExecutor is properly shut down when the map function raises."""
+        # Create a ParallelMapper with method="thread"
+        src = MockSource(num_samples=10)
+        node = ParallelMapper(
+            src,
+            udf_raises,
+            num_workers=2,
+            method="thread",
+        )
+
+        # Reset the node to create the iterator
+        node.reset()
+
+        # Use mock.patch to intercept the ThreadPoolExecutor.shutdown method
+        with mock.patch("concurrent.futures.ThreadPoolExecutor.shutdown") as mock_shutdown:
+            # Consume the iterator to ensure the ThreadPoolExecutor is created
+            # and the exception is raised
+            try:
+                next(node)
+            except ValueError:
+                pass
+
+            # Verify that shutdown was called
+            mock_shutdown.assert_called()
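
Both new tests verify cleanup the same way: patch `ThreadPoolExecutor.shutdown` at the class level and assert that the iterator's teardown path called it. A self-contained sketch of that mocking technique, with a made-up `Holder` class standing in for the mapper's iterator:

```python
import unittest
from concurrent.futures import ThreadPoolExecutor
from unittest import mock


class Holder:
    """Toy object (illustrative, not torchdata) that owns an executor and releases it in a _shutdown() hook."""

    def __init__(self) -> None:
        self.pool = ThreadPoolExecutor(max_workers=2)

    def _shutdown(self) -> None:
        self.pool.shutdown(wait=True)

    def __del__(self) -> None:
        self._shutdown()


class ShutdownTest(unittest.TestCase):
    def test_shutdown_called_on_del(self) -> None:
        holder = Holder()
        # Patch the class attribute so any shutdown call, including one made
        # from __del__, is recorded by the mock.
        with mock.patch("concurrent.futures.ThreadPoolExecutor.shutdown") as mock_shutdown:
            del holder
            mock_shutdown.assert_called()


if __name__ == "__main__":
    unittest.main()
```

Patching the class attribute means any instance's shutdown call is recorded, whether it comes from `__del__`, an explicit `_shutdown()`, or the exception path exercised by the second test.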

torchdata/nodes/map.py

Lines changed: 58 additions & 21 deletions
@@ -7,6 +7,8 @@
 import queue
 import threading
 import time
+
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, Dict, Generic, Iterator, List, Literal, Optional, Protocol, Sequence, TypeVar, Union

 import torch.multiprocessing as mp
@@ -65,7 +67,11 @@ def __call__(self, xlist: Sequence[X]) -> Sequence[T]:
         return [self.map_fn(x) for x in xlist]


-def _sort_worker(in_q: Union[queue.Queue, mp.Queue], out_q: queue.Queue, stop_event: threading.Event):
+def _sort_worker(
+    in_q: Union[queue.Queue, mp.Queue],
+    out_q: queue.Queue,
+    stop_event: threading.Event,
+):
     buffer: Dict[int, Any] = {}
     cur_idx = 0
     while not stop_event.is_set():
@@ -91,6 +97,25 @@ def _sort_worker(in_q: Union[queue.Queue, mp.Queue], out_q: queue.Queue, stop_ev
             cur_idx += 1


+def _transformation_pool(
+    pool: ThreadPoolExecutor,
+    num_workers: int,
+    in_q: queue.Queue,
+    out_q: queue.Queue,
+    map_fn: Callable[[X], T],
+    stop_event: threading.Event,
+):
+    for worker_id in range(num_workers):
+        args = (
+            worker_id,
+            in_q,
+            out_q,
+            map_fn,
+            stop_event,
+        )
+        pool.submit(_apply_udf, *args)
+
+
 class _InlineMapperIter(Iterator[T]):
     """Non-Parallel implementation of Mapper"""

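
`_transformation_pool` submits one long-running `_apply_udf` loop per worker, so the executor's `max_workers` must be at least `num_workers` or some loops would never be scheduled (the constructor change below sizes the pool to exactly `num_workers`). A standalone sketch of that submit-one-loop-per-worker pattern, with a toy `apply_udf` standing in for torchdata's `_apply_udf`:

```python
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Callable


def apply_udf(worker_id: int, in_q: queue.Queue, out_q: queue.Queue,
              map_fn: Callable, stop_event: threading.Event) -> None:
    """Toy worker loop: pull indexed items and emit mapped results until stopped."""
    while not stop_event.is_set():
        try:
            idx, x = in_q.get(timeout=0.1)
        except queue.Empty:
            continue
        out_q.put((idx, map_fn(x)))


def transformation_pool(pool: ThreadPoolExecutor, num_workers: int,
                        in_q: queue.Queue, out_q: queue.Queue,
                        map_fn: Callable, stop_event: threading.Event) -> None:
    # Each submit() occupies one pool thread for the lifetime of the loop,
    # so max_workers must cover num_workers or some loops never start.
    for worker_id in range(num_workers):
        pool.submit(apply_udf, worker_id, in_q, out_q, map_fn, stop_event)


if __name__ == "__main__":
    in_q: queue.Queue = queue.Queue()
    out_q: queue.Queue = queue.Queue()
    stop = threading.Event()
    pool = ThreadPoolExecutor(max_workers=2)
    transformation_pool(pool, 2, in_q, out_q, lambda x: x + 1, stop)
    in_q.put((0, 41))
    print(out_q.get())  # (0, 42)
    stop.set()
    pool.shutdown(wait=True)
```
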
@@ -186,40 +211,49 @@ def __init__(
             name="read_thread(target=_populate_queue)",
             daemon=self.daemonic_reading,
         )
-        self._workers: List[Union[threading.Thread, mp.Process]] = []
-        for worker_id in range(self.num_workers):
-            args = (
-                worker_id,
+        self._read_thread.start()
+
+        if self.method == "thread":
+            self.pool = ThreadPoolExecutor(max_workers=self.num_workers)
+
+            _transformation_pool(
+                self.pool,
+                self.num_workers,
                 self._in_q,
                 self._intermed_q,
                 self.map_fn,
-                self._stop if self.method == "thread" else self._mp_stop,
+                self._stop,
             )
-            self._workers.append(
-                threading.Thread(
-                    target=_apply_udf,
-                    args=args,
-                    daemon=True,
-                    name=f"worker_thread_{worker_id}(target=_apply_udf)",
+
+        elif self.method == "process":
+            self._workers: List[mp.Process] = []
+            for worker_id in range(self.num_workers):
+                args = (
+                    worker_id,
+                    self._in_q,
+                    self._intermed_q,
+                    self.map_fn,
+                    self._mp_stop,
                 )
-                if self.method == "thread"
-                else mp_context.Process(target=_apply_udf, args=args, daemon=True)
-            )
+                self._workers.append(mp_context.Process(target=_apply_udf, args=args, daemon=True))
+            for t in self._workers:
+                t.start()

         self._out_q = self._intermed_q
         if self.in_order:
             self._sort_q: queue.Queue = queue.Queue()
             self._sort_thread = threading.Thread(
                 target=_sort_worker,
-                args=(self._intermed_q, self._sort_q, self._stop),
+                args=(
+                    self._intermed_q,
+                    self._sort_q,
+                    self._stop,
+                ),
                 daemon=True,
                 name="sort_thread(target=_sort_worker)",
             )
             self._out_q = self._sort_q

-        self._read_thread.start()
-        for t in self._workers:
-            t.start()
         if self.in_order:
             self._sort_thread.start()

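
With the thread and process branches now separated, the thread workers always receive the `threading.Event` (`self._stop`) and the process workers the multiprocessing event (`self._mp_stop`); the old per-argument ternary is gone. The split matters because the two event types target different execution contexts, as this standalone sketch (not torchdata code) illustrates:

```python
import multiprocessing as mp
import threading
import time


def wait_for(stop) -> None:
    # Works with either event type: both expose is_set()/set().
    while not stop.is_set():
        time.sleep(0.01)


if __name__ == "__main__":
    # A threading.Event is visible to threads within the same process...
    t_stop = threading.Event()
    t = threading.Thread(target=wait_for, args=(t_stop,), daemon=True)
    t.start()
    t_stop.set()
    t.join()

    # ...but a separate process needs a multiprocessing.Event, whose state
    # is shared across the fork/spawn boundary.
    p_stop = mp.Event()
    p = mp.Process(target=wait_for, args=(p_stop,), daemon=True)
    p.start()
    p_stop.set()
    p.join()
```
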
@@ -260,6 +294,7 @@ def __next__(self) -> T:
             elif isinstance(item, ExceptionWrapper):
                 if not isinstance(item, StartupExceptionWrapper):
                     self._sem.release()
+                self._shutdown()
                 item.reraise()

         self._steps_since_snapshot += 1
@@ -286,12 +321,14 @@ def _shutdown(self):
         self._mp_stop.set()
         if hasattr(self, "_read_thread") and self._read_thread.is_alive():
             self._read_thread.join(timeout=QUEUE_TIMEOUT * 5)
-        if hasattr(self, "_sort_thread") and self._sort_thread.is_alive():
-            self._sort_thread.join(timeout=QUEUE_TIMEOUT * 5)
+        if hasattr(self, "pool"):
+            self.pool.shutdown(wait=True)
         if hasattr(self, "_workers"):
             for t in self._workers:
                 if t.is_alive():
                     t.join(timeout=QUEUE_TIMEOUT * 5)
+        if hasattr(self, "_sort_thread") and self._sort_thread.is_alive():
+            self._sort_thread.join(timeout=QUEUE_TIMEOUT * 5)


 class _ParallelMapperImpl(BaseNode[T]):
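
The reworked `_shutdown` signals the stop events first and then waits for the stages roughly in data-flow order: reader thread, executor workers (or process workers), and finally the sort thread. A generic teardown sketch following the same ordering (all names here are illustrative, not torchdata's):

```python
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

QUEUE_TIMEOUT = 0.1  # illustrative; torchdata defines its own constant


def shutdown_pipeline(
    stop: threading.Event,
    read_thread: Optional[threading.Thread],
    pool: Optional[ThreadPoolExecutor],
    sort_thread: Optional[threading.Thread],
) -> None:
    # 1. Tell every stage to stop pulling new work.
    stop.set()
    # 2. Wait for the reader that feeds the input queue.
    if read_thread is not None and read_thread.is_alive():
        read_thread.join(timeout=QUEUE_TIMEOUT * 5)
    # 3. Wait for the worker loops; shutdown(wait=True) blocks until the
    #    submitted loops observe the stop event and return.
    if pool is not None:
        pool.shutdown(wait=True)
    # 4. Finally wait for the downstream sorter that re-orders results.
    if sort_thread is not None and sort_thread.is_alive():
        sort_thread.join(timeout=QUEUE_TIMEOUT * 5)


if __name__ == "__main__":
    stop = threading.Event()
    pool = ThreadPoolExecutor(max_workers=2)
    reader = threading.Thread(target=stop.wait, daemon=True)
    sorter = threading.Thread(target=stop.wait, daemon=True)
    reader.start()
    sorter.start()
    pool.submit(stop.wait)
    shutdown_pipeline(stop, reader, pool, sorter)
    print("pipeline shut down cleanly")
```

Setting the stop events before `pool.shutdown(wait=True)` is what keeps the call from blocking forever: the long-running worker loops only return once they observe the event.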
