
// iris_core/multicore/shared_worker.rs

use super::{pin_thread_to_core, ChannelDispatcher, SubscriptionStats};
use crate::CoreId;
use crossbeam::channel::{Receiver, Select, TryRecvError};
use serde::Serialize;
use std::fs::File;
use std::io::{Error, Result, Write};
use std::path::Path;
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc, Barrier,
};
use std::thread::{self, sleep, JoinHandle};
use std::time::Duration;

/// Spawns worker threads that share multiple dispatchers, with each thread handling subscriptions
/// from all configured dispatchers using different handlers per dispatcher type.
pub struct SharedWorkerThreadSpawner<T>
where
    T: Send + 'static,
{
    worker_cores: Option<Vec<CoreId>>,
    dispatchers: Vec<Arc<ChannelDispatcher<T>>>,
    handlers: Vec<Box<dyn Fn(T) + Send + Sync>>,
    batch_size: usize,
}

/// Handle for managing a group of shared worker threads.
/// Provides methods for graceful shutdown and statistics access.
pub struct SharedWorkerHandle<T>
where
    T: Send + 'static,
{
    handles: Vec<JoinHandle<()>>,
    dispatchers: Vec<Arc<ChannelDispatcher<T>>>,
    shutdown_signal: Arc<AtomicBool>,
}

/// Builder-style configuration and spawning for a group of shared worker threads.
impl<T> SharedWorkerThreadSpawner<T>
where
    T: Send + Clone + 'static,
{
    /// Creates a new spawner with no cores, dispatchers, or handlers configured.
    pub fn new() -> Self {
        Self {
            worker_cores: None,
            dispatchers: Vec::new(),
            handlers: Vec::new(),
            batch_size: 1,
        }
    }

    /// Sets the CPU cores that worker threads will be pinned to.
    pub fn set_cores(mut self, cores: Vec<CoreId>) -> Self {
        self.worker_cores = Some(cores);
        self
    }

    /// Sets the maximum number of messages processed per batch (clamped to at least 1).
    pub fn set_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size.max(1);
        self
    }

    /// Adds a dispatcher-handler pair. Each dispatcher's subscriptions will be
    /// processed by its corresponding handler.
    pub fn add_dispatcher<F>(mut self, dispatcher: Arc<ChannelDispatcher<T>>, handler: F) -> Self
    where
        F: Fn(T) + Send + Sync + 'static,
    {
        self.dispatchers.push(dispatcher);
        self.handlers.push(Box::new(handler));
        self
    }

    /// Builds a flattened list of all receivers tagged with their dispatcher index.
    /// This allows workers to know which handler to use for each received subscription.
    fn build_tagged_receivers(&self) -> Vec<(usize, Arc<Receiver<T>>)> {
        let mut tagged_receivers = Vec::new();

        for (index, dispatcher) in self.dispatchers.iter().enumerate() {
            for receiver in dispatcher.receivers() {
                tagged_receivers.push((index, receiver));
            }
        }

        tagged_receivers
    }
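
    // Layout sketch: if dispatcher 0 exposes receivers [a, b] and dispatcher 1
    // exposes [c], the tagged list is [(0, a), (0, b), (1, c)], so a select
    // index maps directly back to the owning dispatcher and its handler.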

    /// Spawns worker threads on the configured cores. Each thread processes subscriptions
    /// from all dispatchers using a select operation to handle whichever channel has data available.
    /// Returns a handle for managing the worker group and uses a barrier to ensure all threads are ready.
    pub fn run(self) -> SharedWorkerHandle<T> {
        let tagged_receivers = Arc::new(self.build_tagged_receivers());
        let handlers = Arc::new(self.handlers);
        let dispatchers = Arc::new(self.dispatchers);
        let batch_size = self.batch_size;
        let worker_cores = self
            .worker_cores
            .expect("Cores must be set via set_cores()");

        let num_threads = worker_cores.len();
        let shutdown_signal = Arc::new(AtomicBool::new(false));

        // Barrier to ensure all threads are spawned and pinned before returning;
        // the +1 accounts for the spawning thread itself.
        let startup_barrier = Arc::new(Barrier::new(num_threads + 1));

        let mut handles = Vec::with_capacity(num_threads);
        for core in worker_cores {
            let tagged_receivers_ref = Arc::clone(&tagged_receivers);
            let handlers_ref = Arc::clone(&handlers);
            let dispatchers_ref = Arc::clone(&dispatchers);
            let barrier_ref = Arc::clone(&startup_barrier);
            let shutdown_ref = Arc::clone(&shutdown_signal);

            let handle = thread::spawn(move || {
                if let Err(e) = pin_thread_to_core(core.raw()) {
                    eprintln!("Failed to pin thread to core {core}: {e}");
                }

                // Signal that this thread is ready
                barrier_ref.wait();

                Self::run_worker_loop(
                    &tagged_receivers_ref,
                    &handlers_ref,
                    &dispatchers_ref,
                    batch_size,
                    &shutdown_ref,
                );
            });

            handles.push(handle);
        }

        // Wait for all threads to be ready
        startup_barrier.wait();

        SharedWorkerHandle {
            handles,
            dispatchers: dispatchers.to_vec(),
            shutdown_signal,
        }
    }
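
    // Usage sketch (illustrative only): `make_dispatcher`, `handle_tick`, and
    // `handle_log` are hypothetical stand-ins; the real `ChannelDispatcher`
    // constructor lives elsewhere in this module.
    //
    //     let ticks: Arc<ChannelDispatcher<Msg>> = make_dispatcher("ticks");
    //     let logs: Arc<ChannelDispatcher<Msg>> = make_dispatcher("logs");
    //     let handle = SharedWorkerThreadSpawner::new()
    //         .set_cores(vec![core_a, core_b]) // one worker thread per core
    //         .set_batch_size(64)
    //         .add_dispatcher(Arc::clone(&ticks), handle_tick)
    //         .add_dispatcher(Arc::clone(&logs), handle_log)
    //         .run();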

    /// Process channel messages in batches.
    fn process_batch(
        batch: Vec<T>,
        handler: &(dyn Fn(T) + Send + Sync),
        dispatcher: &Arc<ChannelDispatcher<T>>,
    ) {
        if batch.is_empty() {
            return;
        }

        let batch_size = batch.len() as u64;

        dispatcher
            .stats()
            .actively_processing
            .fetch_add(batch_size, Ordering::Relaxed);

        for data in batch {
            handler(data);
        }

        dispatcher
            .stats()
            .processed
            .fetch_add(batch_size, Ordering::Relaxed);
        dispatcher
            .stats()
            .actively_processing
            .fetch_sub(batch_size, Ordering::Relaxed);
    }

    /// Main worker loop that uses crossbeam `Select` to efficiently wait on multiple channels.
    /// Routes each subscription to the appropriate handler and updates processing statistics.
    fn run_worker_loop(
        tagged_receivers: &[(usize, Arc<Receiver<T>>)],
        handlers: &[Box<dyn Fn(T) + Send + Sync>],
        dispatchers: &[Arc<ChannelDispatcher<T>>],
        batch_size: usize,
        shutdown_signal: &Arc<AtomicBool>,
    ) {
        let mut select = Select::new();
        for (_, receiver) in tagged_receivers.iter() {
            select.recv(receiver);
        }

        loop {
            if shutdown_signal.load(Ordering::Relaxed) {
                break;
            }

            // Block until one of the registered channels is ready, then map the
            // ready index back to its handler and dispatcher.
            let oper = select.select();
            let oper_index = oper.index();
            let (handler_index, receiver) = &tagged_receivers[oper_index];
            let handler = &handlers[*handler_index];
            let dispatcher = &dispatchers[*handler_index];

            let mut batch = Vec::with_capacity(batch_size);
            let mut recv_error: Option<TryRecvError> = None;

            match oper.recv(receiver) {
                Ok(msg) => {
                    batch.push(msg);
                }
                Err(_) => {
                    // Channel is disconnected, exit the loop
                    break;
                }
            }

            // Opportunistically drain the same channel, without blocking, until
            // the batch is full, the channel is empty, or it disconnects.
            while batch.len() < batch_size {
                match receiver.try_recv() {
                    Ok(msg) => {
                        batch.push(msg);
                    }
                    Err(e) => {
                        recv_error = Some(e);
                        break;
                    }
                }
            }

            if !batch.is_empty() {
                Self::process_batch(batch, handler.as_ref(), dispatcher);
            }

            if let Some(err) = recv_error {
                match err {
                    TryRecvError::Empty => {
                        continue; // Channel is empty, go back to select
                    }
                    TryRecvError::Disconnected => {
                        break; // Channel closed, exit the loop
                    }
                }
            }
        }
    }
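
    // Note: a worker parked in `select()` cannot observe the shutdown flag until
    // an operation completes. The handle's shutdown paths therefore close all
    // channels after setting the flag; the resulting recv errors unblock the
    // select and let the loop above break.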
}

impl<T> Default for SharedWorkerThreadSpawner<T>
where
    T: Send + Clone + 'static,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<T> SharedWorkerHandle<T>
where
    T: Send + 'static,
{
    /// Blocks until all queues are empty and no messages are actively processing.
    pub fn wait_for_completion(&self) {
        loop {
            let all_complete = self.dispatchers.iter().all(|dispatcher| {
                let queues_empty = dispatcher.receivers().iter().all(|r| r.is_empty());
                let active_handlers = dispatcher.stats().get_actively_processing();

                queues_empty && active_handlers == 0
            });

            if all_complete {
                break;
            }

            // Small sleep to avoid busy-waiting
            sleep(Duration::from_millis(10));
        }
    }

    /// Gracefully shuts down all worker threads.
    /// If `flush_dir` is provided, all remaining channel contents are flushed to disk.
    /// Otherwise, this blocks until every queued message has been processed, which
    /// can make the call appear to stall under a backlog. Returns a final
    /// statistics snapshot for each dispatcher.
    pub fn shutdown(mut self, flush_dir: Option<&Path>) -> Vec<SubscriptionStats>
    where
        T: Serialize,
    {
        if let Some(dir) = flush_dir {
            self.flush_shutdown(dir);
        } else {
            self.complete_shutdown();
        }

        self.dispatchers
            .iter()
            .map(|dispatcher| dispatcher.stats().snapshot())
            .collect()
    }
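
    // Shutdown sketch (illustrative; either call consumes the handle):
    //
    //     // Drain: block until every queued message has been processed.
    //     let stats = handle.shutdown(None);
    //
    //     // Flush: stop without draining and write any unprocessed messages
    //     // to `<flush_dir>/<dispatcher-name>.json`.
    //     let stats = handle.shutdown(Some(Path::new("/tmp/flush")));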

    fn complete_shutdown(&mut self) {
        self.wait_for_completion();
        self.shutdown_signal.store(true, Ordering::SeqCst);

        for dispatcher in &self.dispatchers {
            dispatcher.close_channels();
        }

        for (i, handle) in self.handles.drain(..).enumerate() {
            if let Err(e) = handle.join() {
                eprintln!("Thread {i} error: {e:?}");
            }
        }
    }

    fn flush_shutdown(&mut self, flush_dir: &Path)
    where
        T: Serialize,
    {
        self.shutdown_signal.store(true, Ordering::SeqCst);

        for dispatcher in &self.dispatchers {
            dispatcher.close_channels();
        }

        for (i, handle) in self.handles.drain(..).enumerate() {
            if let Err(e) = handle.join() {
                eprintln!("Thread {i} error: {e:?}");
            }
        }

        for (i, dispatcher) in self.dispatchers.iter().enumerate() {
            let mut flushed_messages = Vec::new();

            for receiver in dispatcher.receivers().iter() {
                while let Ok(message) = receiver.try_recv() {
                    flushed_messages.push(message);
                }
            }

            let message_count = flushed_messages.len() as u64;
            if message_count == 0 {
                continue;
            }

            let file_path = flush_dir.join(format!("{}.json", dispatcher.name()));

            if flush_messages(&flushed_messages, &file_path).is_ok() {
                println!(
                    "Dispatcher {i}: flushed {message_count} messages to {}",
                    file_path.display()
                );
                dispatcher
                    .stats()
                    .flushed
                    .fetch_add(message_count, Ordering::Relaxed);
            } else {
                eprintln!("Dispatcher {i}: error flushing, dropped {message_count} messages");
                dispatcher
                    .stats()
                    .dropped
                    .fetch_add(message_count, Ordering::Relaxed);
            }
        }
    }
}

/// Writes messages to disk as a pretty-printed JSON array.
fn flush_messages<T: Serialize>(messages: &[T], path: &Path) -> Result<()> {
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }

    let mut file = File::create(path)?;
    let json_str = serde_json::to_string_pretty(messages).map_err(Error::other)?;
    writeln!(file, "{json_str}")?;
    Ok(())
}
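
// Output sketch: `flush_messages` serializes the slice as one pretty-printed
// JSON array, so flushing three `u64` messages would produce:
//
//     [
//       1,
//       2,
//       3
//     ]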