iris_core/multicore/
shared_worker.rs1use super::{pin_thread_to_core, ChannelDispatcher, SubscriptionStats};
2use crate::CoreId;
3use crossbeam::channel::{Receiver, Select, TryRecvError};
4use serde::Serialize;
5use std::fs::File;
6use std::io::{Error, Result, Write};
7use std::path::{Path, PathBuf};
8use std::sync::{
9 atomic::{AtomicBool, Ordering},
10 Arc, Barrier,
11};
12use std::thread::{self, sleep, JoinHandle};
13use std::time::Duration;
14
15pub struct SharedWorkerThreadSpawner<T>
18where
19 T: Send + 'static,
20{
21 worker_cores: Option<Vec<CoreId>>,
22 dispatchers: Vec<Arc<ChannelDispatcher<T>>>,
23 handlers: Vec<Box<dyn Fn(T) + Send + Sync>>,
24 batch_size: usize,
25}
26
27pub struct SharedWorkerHandle<T>
30where
31 T: Send + 'static,
32{
33 handles: Vec<JoinHandle<()>>,
34 dispatchers: Vec<Arc<ChannelDispatcher<T>>>,
35 shutdown_signal: Arc<AtomicBool>,
36}
37
38impl<T> SharedWorkerThreadSpawner<T>
40where
41 T: Send + Clone + 'static,
42{
43 pub fn new() -> Self {
45 Self {
46 worker_cores: None,
47 dispatchers: Vec::new(),
48 handlers: Vec::new(),
49 batch_size: 1,
50 }
51 }
52
53 pub fn set_cores(mut self, cores: Vec<CoreId>) -> Self {
55 self.worker_cores = Some(cores);
56 self
57 }
58
59 pub fn set_batch_size(mut self, batch_size: usize) -> Self {
61 self.batch_size = batch_size.max(1);
62 self
63 }
64
65 pub fn add_dispatcher<F>(mut self, dispatcher: Arc<ChannelDispatcher<T>>, handler: F) -> Self
67 where
68 F: Fn(T) + Send + Sync + 'static,
69 {
70 self.dispatchers.push(dispatcher);
71 self.handlers.push(Box::new(handler));
72 self
73 }
74
75 fn build_tagged_receivers(&self) -> Vec<(usize, Arc<Receiver<T>>)> {
78 let mut tagged_receivers = Vec::new();
79
80 for (index, dispatcher) in self.dispatchers.iter().enumerate() {
81 let receivers = dispatcher.receivers();
82 for receiver in receivers {
83 tagged_receivers.push((index, receiver));
84 }
85 }
86
87 tagged_receivers
88 }
89
90 pub fn run(self) -> SharedWorkerHandle<T> {
94 let tagged_receivers = Arc::new(self.build_tagged_receivers());
95 let handlers = Arc::new(self.handlers);
96 let dispatchers = Arc::new(self.dispatchers);
97 let batch_size = self.batch_size;
98 let worker_cores = self
99 .worker_cores
100 .expect("Cores must be set via set_cores()");
101
102 let num_threads = worker_cores.len();
103 let shutdown_signal = Arc::new(AtomicBool::new(false));
104
105 let startup_barrier = Arc::new(Barrier::new(num_threads + 1)); let mut handles = Vec::with_capacity(num_threads);
109 for core in worker_cores {
110 let tagged_receivers_ref = Arc::clone(&tagged_receivers);
111 let handlers_ref = Arc::clone(&handlers);
112 let dispatchers_ref = dispatchers.clone();
113 let barrier_ref = Arc::clone(&startup_barrier);
114 let shutdown_ref = Arc::clone(&shutdown_signal);
115
116 let handle = thread::spawn(move || {
117 if let Err(e) = pin_thread_to_core(core.raw()) {
118 eprintln!("Failed to pin thread to core {core}: {e}");
119 }
120
121 barrier_ref.wait();
123
124 Self::run_worker_loop(
125 &tagged_receivers_ref,
126 &handlers_ref,
127 &dispatchers_ref,
128 batch_size,
129 &shutdown_ref,
130 );
131 });
132
133 handles.push(handle);
134 }
135
136 startup_barrier.wait();
138
139 SharedWorkerHandle {
140 handles,
141 dispatchers: dispatchers.to_vec(),
142 shutdown_signal,
143 }
144 }
145
146 fn process_batch(
148 batch: Vec<T>,
149 handler: &(dyn Fn(T) + Send + Sync),
150 dispatcher: &Arc<ChannelDispatcher<T>>,
151 ) {
152 if batch.is_empty() {
153 return;
154 }
155
156 let batch_size = batch.len() as u64;
157
158 dispatcher
159 .stats()
160 .actively_processing
161 .fetch_add(batch_size, Ordering::Relaxed);
162
163 for data in batch {
164 handler(data);
165 }
166
167 dispatcher
168 .stats()
169 .processed
170 .fetch_add(batch_size, Ordering::Relaxed);
171 dispatcher
172 .stats()
173 .actively_processing
174 .fetch_sub(batch_size, Ordering::Relaxed);
175 }
176
177 fn run_worker_loop(
180 tagged_receivers: &[(usize, Arc<Receiver<T>>)],
181 handlers: &[Box<dyn Fn(T) + Send + Sync>],
182 dispatchers: &[Arc<ChannelDispatcher<T>>],
183 batch_size: usize,
184 shutdown_signal: &Arc<AtomicBool>,
185 ) {
186 let mut select = Select::new();
187 for (_, receiver) in tagged_receivers.iter() {
188 select.recv(receiver);
189 }
190
191 loop {
192 if shutdown_signal.load(Ordering::Relaxed) {
193 break;
194 }
195
196 let oper = select.select();
197 let oper_index = oper.index();
198 let (handler_index, receiver) = &tagged_receivers[oper_index];
199 let handler = &handlers[*handler_index];
200 let dispatcher = &dispatchers[*handler_index];
201
202 let mut batch = Vec::with_capacity(batch_size);
203 let mut recv_error: Option<TryRecvError> = None;
204
205 match oper.recv(receiver) {
206 Ok(msg) => {
207 batch.push(msg);
208 }
209 Err(_) => {
210 break;
212 }
213 }
214
215 for _ in 0..batch_size {
216 match receiver.try_recv() {
217 Ok(msg) => {
218 batch.push(msg);
219 }
220 Err(e) => {
221 recv_error = Some(e);
222 break;
223 }
224 }
225 }
226
227 if !batch.is_empty() {
228 Self::process_batch(batch, handler.as_ref(), dispatcher);
229 }
230
231 if let Some(err) = recv_error {
232 match err {
233 TryRecvError::Empty => {
234 continue; }
236 TryRecvError::Disconnected => {
237 break; }
239 }
240 }
241 }
242 }
243}
244
245impl<T> Default for SharedWorkerThreadSpawner<T>
246where
247 T: Send + Clone + 'static,
248{
249 fn default() -> Self {
250 Self::new()
251 }
252}
253
254impl<T> SharedWorkerHandle<T>
255where
256 T: Send + 'static,
257{
258 pub fn wait_for_completion(&self) {
260 loop {
261 let all_complete = self.dispatchers.iter().all(|dispatcher| {
262 let receivers = dispatcher.receivers();
263 let queues_empty = receivers.iter().all(|r| r.is_empty());
264 let active_handlers = dispatcher.stats().get_actively_processing();
265
266 queues_empty && active_handlers == 0
267 });
268
269 if all_complete {
270 break;
271 }
272
273 sleep(Duration::from_millis(10));
275 }
276 }
277
278 pub fn shutdown(mut self, flush_dir: Option<&PathBuf>) -> Vec<SubscriptionStats>
285 where
286 T: Serialize,
287 {
288 if let Some(dir) = flush_dir {
289 self.flush_shutdown(dir);
290 } else {
291 self.complete_shutdown();
292 }
293
294 self.dispatchers
295 .iter()
296 .map(|dispatcher| dispatcher.stats().snapshot())
297 .collect()
298 }
299
300 fn complete_shutdown(&mut self) {
301 self.wait_for_completion();
302 self.shutdown_signal.store(true, Ordering::SeqCst);
303
304 for dispatcher in &self.dispatchers {
305 dispatcher.close_channels();
306 }
307
308 for (i, handle) in self.handles.drain(..).enumerate() {
309 if let Err(e) = handle.join() {
310 eprintln!("Thread {i} error: {e:?}");
311 }
312 }
313 }
314
315 fn flush_shutdown(&mut self, flush_dir: &Path)
316 where
317 T: Serialize,
318 {
319 self.shutdown_signal.store(true, Ordering::SeqCst);
320
321 for dispatcher in &self.dispatchers {
322 dispatcher.close_channels();
323 }
324
325 for (i, handle) in self.handles.drain(..).enumerate() {
326 if let Err(e) = handle.join() {
327 eprintln!("Thread {i} error: {e:?}");
328 }
329 }
330
331 for (i, dispatcher) in self.dispatchers.iter().enumerate() {
332 let mut flushed_messages = Vec::new();
333
334 let receivers = dispatcher.receivers();
335 for receiver in receivers.iter() {
336 while let Ok(message) = receiver.try_recv() {
337 flushed_messages.push(message);
338 }
339 }
340
341 let message_count = flushed_messages.len() as u64;
342 if message_count == 0 {
343 continue;
344 }
345
346 let file_path = flush_dir.join(format!("{}.json", dispatcher.name()));
347
348 if flush_messages(&flushed_messages, &file_path).is_ok() {
349 println!(
350 "Dispatcher {i}: flushed {message_count} messages to {}",
351 file_path.display()
352 );
353 dispatcher
354 .stats()
355 .flushed
356 .fetch_add(message_count, Ordering::Relaxed);
357 } else {
358 eprintln!("Dispatcher {i}: error flushing, dropped {message_count} messages");
359 dispatcher
360 .stats()
361 .dropped
362 .fetch_add(message_count, Ordering::Relaxed);
363 }
364 }
365 }
366}
367
368fn flush_messages<T: Serialize>(messages: &[T], path: &PathBuf) -> Result<()> {
370 if let Some(parent) = path.parent() {
371 std::fs::create_dir_all(parent)?;
372 }
373
374 let mut file = File::create(path)?;
375 let json_str = serde_json::to_string_pretty(messages).map_err(Error::other)?;
376 writeln!(file, "{json_str}")?;
377 Ok(())
378}