77import queue
88import threading
99import time
10+
11+ from concurrent .futures import ThreadPoolExecutor
1012from typing import Any , Callable , Dict , Generic , Iterator , List , Literal , Optional , Protocol , Sequence , TypeVar , Union
1113
1214import torch .multiprocessing as mp
@@ -65,7 +67,11 @@ def __call__(self, xlist: Sequence[X]) -> Sequence[T]:
6567 return [self .map_fn (x ) for x in xlist ]
6668
6769
68- def _sort_worker (in_q : Union [queue .Queue , mp .Queue ], out_q : queue .Queue , stop_event : threading .Event ):
70+ def _sort_worker (
71+ in_q : Union [queue .Queue , mp .Queue ],
72+ out_q : queue .Queue ,
73+ stop_event : threading .Event ,
74+ ):
6975 buffer : Dict [int , Any ] = {}
7076 cur_idx = 0
7177 while not stop_event .is_set ():
@@ -91,6 +97,25 @@ def _sort_worker(in_q: Union[queue.Queue, mp.Queue], out_q: queue.Queue, stop_ev
9197 cur_idx += 1
9298
9399
def _transformation_pool(
    pool: ThreadPoolExecutor,
    num_workers: int,
    in_q: queue.Queue,
    out_q: queue.Queue,
    map_fn: Callable[[X], T],
    stop_event: threading.Event,
):
    """Submit one ``_apply_udf`` task per worker to ``pool``.

    Each task is called as
    ``_apply_udf(worker_id, in_q, out_q, map_fn, stop_event)`` with
    ``worker_id`` taking the values ``0 .. num_workers - 1``.

    Args:
        pool: Executor the tasks are submitted to.
        num_workers: Number of ``_apply_udf`` tasks to submit.
        in_q: Queue forwarded to every task as its input queue.
        out_q: Queue forwarded to every task as its output queue.
        map_fn: Callable forwarded to every task.
        stop_event: Event forwarded to every task.

    The futures returned by ``pool.submit`` are not kept, and the function
    returns ``None``.
    """
    # Everything except the worker id is identical across tasks, so build
    # the shared tail of the argument list once.
    shared_args = (in_q, out_q, map_fn, stop_event)
    for idx in range(num_workers):
        pool.submit(_apply_udf, idx, *shared_args)
117+
118+
94119class _InlineMapperIter (Iterator [T ]):
95120 """Non-Parallel implementation of Mapper"""
96121
@@ -186,40 +211,49 @@ def __init__(
186211 name = "read_thread(target=_populate_queue)" ,
187212 daemon = self .daemonic_reading ,
188213 )
189- self ._workers : List [Union [threading .Thread , mp .Process ]] = []
190- for worker_id in range (self .num_workers ):
191- args = (
192- worker_id ,
214+ self ._read_thread .start ()
215+
216+ if self .method == "thread" :
217+ self .pool = ThreadPoolExecutor (max_workers = self .num_workers )
218+
219+ _transformation_pool (
220+ self .pool ,
221+ self .num_workers ,
193222 self ._in_q ,
194223 self ._intermed_q ,
195224 self .map_fn ,
196- self ._stop if self . method == "thread" else self . _mp_stop ,
225+ self ._stop ,
197226 )
198- self ._workers .append (
199- threading .Thread (
200- target = _apply_udf ,
201- args = args ,
202- daemon = True ,
203- name = f"worker_thread_{ worker_id } (target=_apply_udf)" ,
227+
228+ elif self .method == "process" :
229+ self ._workers : List [mp .Process ] = []
230+ for worker_id in range (self .num_workers ):
231+ args = (
232+ worker_id ,
233+ self ._in_q ,
234+ self ._intermed_q ,
235+ self .map_fn ,
236+ self ._mp_stop ,
204237 )
205- if self .method == "thread"
206- else mp_context . Process ( target = _apply_udf , args = args , daemon = True )
207- )
238+ self ._workers . append ( mp_context . Process ( target = _apply_udf , args = args , daemon = True ))
239+ for t in self . _workers :
240+ t . start ( )
208241
209242 self ._out_q = self ._intermed_q
210243 if self .in_order :
211244 self ._sort_q : queue .Queue = queue .Queue ()
212245 self ._sort_thread = threading .Thread (
213246 target = _sort_worker ,
214- args = (self ._intermed_q , self ._sort_q , self ._stop ),
247+ args = (
248+ self ._intermed_q ,
249+ self ._sort_q ,
250+ self ._stop ,
251+ ),
215252 daemon = True ,
216253 name = "sort_thread(target=_sort_worker)" ,
217254 )
218255 self ._out_q = self ._sort_q
219256
220- self ._read_thread .start ()
221- for t in self ._workers :
222- t .start ()
223257 if self .in_order :
224258 self ._sort_thread .start ()
225259
@@ -260,6 +294,7 @@ def __next__(self) -> T:
260294 elif isinstance (item , ExceptionWrapper ):
261295 if not isinstance (item , StartupExceptionWrapper ):
262296 self ._sem .release ()
297+ self ._shutdown ()
263298 item .reraise ()
264299
265300 self ._steps_since_snapshot += 1
@@ -286,12 +321,14 @@ def _shutdown(self):
286321 self ._mp_stop .set ()
287322 if hasattr (self , "_read_thread" ) and self ._read_thread .is_alive ():
288323 self ._read_thread .join (timeout = QUEUE_TIMEOUT * 5 )
289- if hasattr (self , "_sort_thread" ) and self . _sort_thread . is_alive ( ):
290- self ._sort_thread . join ( timeout = QUEUE_TIMEOUT * 5 )
324+ if hasattr (self , "pool" ):
325+ self .pool . shutdown ( wait = True )
291326 if hasattr (self , "_workers" ):
292327 for t in self ._workers :
293328 if t .is_alive ():
294329 t .join (timeout = QUEUE_TIMEOUT * 5 )
330+ if hasattr (self , "_sort_thread" ) and self ._sort_thread .is_alive ():
331+ self ._sort_thread .join (timeout = QUEUE_TIMEOUT * 5 )
295332
296333
297334class _ParallelMapperImpl (BaseNode [T ]):
0 commit comments