Skip to content

Commit 911d6d0

Browse files
bors[bot]cuviper
andauthored
Merge #492
492: Add ThreadPool::broadcast r=cuviper a=cuviper A broadcast runs the closure on every thread in the pool, then collects the results. It's scheduled somewhat like a very soft interrupt -- it won't preempt a thread's local work, but will run before it goes to steal from any other threads. This can be used when you want to precisely split your work per-thread, or to set or retrieve some thread-local data in the pool, e.g. #483. Co-authored-by: Josh Stone <cuviper@gmail.com>
2 parents 274499a + 9ef85cd commit 911d6d0

File tree

12 files changed

+778
-86
lines changed

12 files changed

+778
-86
lines changed

rayon-core/src/broadcast/mod.rs

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
use crate::job::{ArcJob, StackJob};
2+
use crate::registry::{Registry, WorkerThread};
3+
use crate::scope::ScopeLatch;
4+
use std::fmt;
5+
use std::marker::PhantomData;
6+
use std::sync::Arc;
7+
8+
mod test;
9+
10+
/// Executes `op` within every thread in the current threadpool. If this is
11+
/// called from a non-Rayon thread, it will execute in the global threadpool.
12+
/// Any attempts to use `join`, `scope`, or parallel iterators will then operate
13+
/// within that threadpool. When the call has completed on each thread, returns
14+
/// a vector containing all of their return values.
15+
///
16+
/// For more information, see the [`ThreadPool::broadcast()`][m] method.
17+
///
18+
/// [m]: struct.ThreadPool.html#method.broadcast
19+
pub fn broadcast<OP, R>(op: OP) -> Vec<R>
20+
where
21+
OP: Fn(BroadcastContext<'_>) -> R + Sync,
22+
R: Send,
23+
{
24+
// We assert that current registry has not terminated.
25+
unsafe { broadcast_in(op, &Registry::current()) }
26+
}
27+
28+
/// Spawns an asynchronous task on every thread in this thread-pool. This task
29+
/// will run in the implicit, global scope, which means that it may outlast the
30+
/// current stack frame -- therefore, it cannot capture any references onto the
31+
/// stack (you will likely need a `move` closure).
32+
///
33+
/// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method.
34+
///
35+
/// [m]: struct.ThreadPool.html#method.spawn_broadcast
36+
pub fn spawn_broadcast<OP>(op: OP)
37+
where
38+
OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
39+
{
40+
// We assert that current registry has not terminated.
41+
unsafe { spawn_broadcast_in(op, &Registry::current()) }
42+
}
43+
44+
/// Provides context to a closure called by `broadcast`.
45+
pub struct BroadcastContext<'a> {
46+
worker: &'a WorkerThread,
47+
48+
/// Make sure to prevent auto-traits like `Send` and `Sync`.
49+
_marker: PhantomData<&'a mut dyn Fn()>,
50+
}
51+
52+
impl<'a> BroadcastContext<'a> {
53+
pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R {
54+
let worker_thread = WorkerThread::current();
55+
assert!(!worker_thread.is_null());
56+
f(BroadcastContext {
57+
worker: unsafe { &*worker_thread },
58+
_marker: PhantomData,
59+
})
60+
}
61+
62+
/// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`).
63+
#[inline]
64+
pub fn index(&self) -> usize {
65+
self.worker.index()
66+
}
67+
68+
/// The number of threads receiving the broadcast in the thread pool.
69+
///
70+
/// # Future compatibility note
71+
///
72+
/// Future versions of Rayon might vary the number of threads over time, but
73+
/// this method will always return the number of threads which are actually
74+
/// receiving your particular `broadcast` call.
75+
#[inline]
76+
pub fn num_threads(&self) -> usize {
77+
self.worker.registry().num_threads()
78+
}
79+
}
80+
81+
impl<'a> fmt::Debug for BroadcastContext<'a> {
82+
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
83+
fmt.debug_struct("BroadcastContext")
84+
.field("index", &self.index())
85+
.field("num_threads", &self.num_threads())
86+
.field("pool_id", &self.worker.registry().id())
87+
.finish()
88+
}
89+
}
90+
91+
/// Execute `op` on every thread in the pool. It will be executed on each
92+
/// thread when they have nothing else to do locally, before they try to
93+
/// steal work from other threads. This function will not return until all
94+
/// threads have completed the `op`.
95+
///
96+
/// Unsafe because `registry` must not yet have terminated.
97+
pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R>
98+
where
99+
OP: Fn(BroadcastContext<'_>) -> R + Sync,
100+
R: Send,
101+
{
102+
let f = move |injected: bool| {
103+
debug_assert!(injected);
104+
BroadcastContext::with(&op)
105+
};
106+
107+
let n_threads = registry.num_threads();
108+
let current_thread = WorkerThread::current().as_ref();
109+
let latch = ScopeLatch::with_count(n_threads, current_thread);
110+
let jobs: Vec<_> = (0..n_threads).map(|_| StackJob::new(&f, &latch)).collect();
111+
let job_refs = jobs.iter().map(|job| job.as_job_ref());
112+
113+
registry.inject_broadcast(job_refs);
114+
115+
// Wait for all jobs to complete, then collect the results, maybe propagating a panic.
116+
latch.wait(current_thread);
117+
jobs.into_iter().map(|job| job.into_result()).collect()
118+
}
119+
120+
/// Execute `op` on every thread in the pool. It will be executed on each
121+
/// thread when they have nothing else to do locally, before they try to
122+
/// steal work from other threads. This function returns immediately after
123+
/// injecting the jobs.
124+
///
125+
/// Unsafe because `registry` must not yet have terminated.
126+
pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>)
127+
where
128+
OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
129+
{
130+
let job = ArcJob::new({
131+
let registry = Arc::clone(registry);
132+
move || {
133+
registry.catch_unwind(|| BroadcastContext::with(&op));
134+
registry.terminate(); // (*) permit registry to terminate now
135+
}
136+
});
137+
138+
let n_threads = registry.num_threads();
139+
let job_refs = (0..n_threads).map(|_| {
140+
// Ensure that registry cannot terminate until this job has executed
141+
// on each thread. This ref is decremented at the (*) above.
142+
registry.increment_terminate_count();
143+
144+
ArcJob::as_static_job_ref(&job)
145+
});
146+
147+
registry.inject_broadcast(job_refs);
148+
}

rayon-core/src/broadcast/test.rs

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
#![cfg(test)]
2+
3+
use crate::ThreadPoolBuilder;
4+
use std::sync::atomic::{AtomicUsize, Ordering};
5+
use std::sync::Arc;
6+
use std::{thread, time};
7+
8+
#[test]
9+
fn broadcast_global() {
10+
let v = crate::broadcast(|ctx| ctx.index());
11+
assert!(v.into_iter().eq(0..crate::current_num_threads()));
12+
}
13+
14+
#[test]
15+
fn spawn_broadcast_global() {
16+
let (tx, rx) = crossbeam_channel::unbounded();
17+
crate::spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap());
18+
19+
let mut v: Vec<_> = rx.into_iter().collect();
20+
v.sort_unstable();
21+
assert!(v.into_iter().eq(0..crate::current_num_threads()));
22+
}
23+
24+
#[test]
25+
fn broadcast_pool() {
26+
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
27+
let v = pool.broadcast(|ctx| ctx.index());
28+
assert!(v.into_iter().eq(0..7));
29+
}
30+
31+
#[test]
32+
fn spawn_broadcast_pool() {
33+
let (tx, rx) = crossbeam_channel::unbounded();
34+
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
35+
pool.spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap());
36+
37+
let mut v: Vec<_> = rx.into_iter().collect();
38+
v.sort_unstable();
39+
assert!(v.into_iter().eq(0..7));
40+
}
41+
42+
#[test]
43+
fn broadcast_self() {
44+
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
45+
let v = pool.install(|| crate::broadcast(|ctx| ctx.index()));
46+
assert!(v.into_iter().eq(0..7));
47+
}
48+
49+
#[test]
50+
fn spawn_broadcast_self() {
51+
let (tx, rx) = crossbeam_channel::unbounded();
52+
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
53+
pool.spawn(|| crate::spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap()));
54+
55+
let mut v: Vec<_> = rx.into_iter().collect();
56+
v.sort_unstable();
57+
assert!(v.into_iter().eq(0..7));
58+
}
59+
60+
#[test]
61+
fn broadcast_mutual() {
62+
let count = AtomicUsize::new(0);
63+
let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap();
64+
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
65+
pool1.install(|| {
66+
pool2.broadcast(|_| {
67+
pool1.broadcast(|_| {
68+
count.fetch_add(1, Ordering::Relaxed);
69+
})
70+
})
71+
});
72+
assert_eq!(count.into_inner(), 3 * 7);
73+
}
74+
75+
#[test]
76+
fn spawn_broadcast_mutual() {
77+
let (tx, rx) = crossbeam_channel::unbounded();
78+
let pool1 = Arc::new(ThreadPoolBuilder::new().num_threads(3).build().unwrap());
79+
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
80+
pool1.spawn({
81+
let pool1 = Arc::clone(&pool1);
82+
move || {
83+
pool2.spawn_broadcast(move |_| {
84+
let tx = tx.clone();
85+
pool1.spawn_broadcast(move |_| tx.send(()).unwrap())
86+
})
87+
}
88+
});
89+
assert_eq!(rx.into_iter().count(), 3 * 7);
90+
}
91+
92+
#[test]
93+
fn broadcast_mutual_sleepy() {
94+
let count = AtomicUsize::new(0);
95+
let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap();
96+
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
97+
pool1.install(|| {
98+
thread::sleep(time::Duration::from_secs(1));
99+
pool2.broadcast(|_| {
100+
thread::sleep(time::Duration::from_secs(1));
101+
pool1.broadcast(|_| {
102+
thread::sleep(time::Duration::from_millis(100));
103+
count.fetch_add(1, Ordering::Relaxed);
104+
})
105+
})
106+
});
107+
assert_eq!(count.into_inner(), 3 * 7);
108+
}
109+
110+
#[test]
111+
fn spawn_broadcast_mutual_sleepy() {
112+
let (tx, rx) = crossbeam_channel::unbounded();
113+
let pool1 = Arc::new(ThreadPoolBuilder::new().num_threads(3).build().unwrap());
114+
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
115+
pool1.spawn({
116+
let pool1 = Arc::clone(&pool1);
117+
move || {
118+
thread::sleep(time::Duration::from_secs(1));
119+
pool2.spawn_broadcast(move |_| {
120+
let tx = tx.clone();
121+
thread::sleep(time::Duration::from_secs(1));
122+
pool1.spawn_broadcast(move |_| {
123+
thread::sleep(time::Duration::from_millis(100));
124+
tx.send(()).unwrap();
125+
})
126+
})
127+
}
128+
});
129+
assert_eq!(rx.into_iter().count(), 3 * 7);
130+
}
131+
132+
#[test]
133+
fn broadcast_panic_one() {
134+
let count = AtomicUsize::new(0);
135+
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
136+
let result = crate::unwind::halt_unwinding(|| {
137+
pool.broadcast(|ctx| {
138+
count.fetch_add(1, Ordering::Relaxed);
139+
if ctx.index() == 3 {
140+
panic!("Hello, world!");
141+
}
142+
})
143+
});
144+
assert_eq!(count.into_inner(), 7);
145+
assert!(result.is_err(), "broadcast panic should propagate!");
146+
}
147+
148+
#[test]
149+
fn spawn_broadcast_panic_one() {
150+
let (tx, rx) = crossbeam_channel::unbounded();
151+
let (panic_tx, panic_rx) = crossbeam_channel::unbounded();
152+
let pool = ThreadPoolBuilder::new()
153+
.num_threads(7)
154+
.panic_handler(move |e| panic_tx.send(e).unwrap())
155+
.build()
156+
.unwrap();
157+
pool.spawn_broadcast(move |ctx| {
158+
tx.send(()).unwrap();
159+
if ctx.index() == 3 {
160+
panic!("Hello, world!");
161+
}
162+
});
163+
drop(pool); // including panic_tx
164+
assert_eq!(rx.into_iter().count(), 7);
165+
assert_eq!(panic_rx.into_iter().count(), 1);
166+
}
167+
168+
#[test]
169+
fn broadcast_panic_many() {
170+
let count = AtomicUsize::new(0);
171+
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
172+
let result = crate::unwind::halt_unwinding(|| {
173+
pool.broadcast(|ctx| {
174+
count.fetch_add(1, Ordering::Relaxed);
175+
if ctx.index() % 2 == 0 {
176+
panic!("Hello, world!");
177+
}
178+
})
179+
});
180+
assert_eq!(count.into_inner(), 7);
181+
assert!(result.is_err(), "broadcast panic should propagate!");
182+
}
183+
184+
#[test]
185+
fn spawn_broadcast_panic_many() {
186+
let (tx, rx) = crossbeam_channel::unbounded();
187+
let (panic_tx, panic_rx) = crossbeam_channel::unbounded();
188+
let pool = ThreadPoolBuilder::new()
189+
.num_threads(7)
190+
.panic_handler(move |e| panic_tx.send(e).unwrap())
191+
.build()
192+
.unwrap();
193+
pool.spawn_broadcast(move |ctx| {
194+
tx.send(()).unwrap();
195+
if ctx.index() % 2 == 0 {
196+
panic!("Hello, world!");
197+
}
198+
});
199+
drop(pool); // including panic_tx
200+
assert_eq!(rx.into_iter().count(), 7);
201+
assert_eq!(panic_rx.into_iter().count(), 4);
202+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy