iris_core/multicore/dedicated_worker.rs

use super::{pin_thread_to_core, ChannelDispatcher, SubscriptionStats};
use crate::CoreId;
use crossbeam::channel::{Receiver, Select, TryRecvError};
use serde::Serialize;
use std::fs::File;
use std::io::{Error, Result, Write};
use std::path::{Path, PathBuf};
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc, Barrier,
};
use std::thread::{self, sleep, JoinHandle};
use std::time::Duration;

/// Spawns worker threads dedicated to a single dispatcher, with all threads running the same handler function.
/// Optimized for the single-dispatcher case: each worker selects only over that dispatcher's receivers.
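///
/// A rough usage sketch; the dispatcher construction and the `CoreId` values
/// shown here are assumed for illustration and are not defined in this file:
///
/// ```ignore
/// use std::sync::Arc;
///
/// // Hypothetical setup: obtain a dispatcher and the cores to pin workers to.
/// let dispatcher: Arc<ChannelDispatcher<u64>> = Arc::new(make_dispatcher());
/// let cores: Vec<CoreId> = pick_worker_cores();
///
/// let handle = DedicatedWorkerThreadSpawner::new()
///     .set_cores(cores)
///     .set_batch_size(32)
///     .set_dispatcher(Arc::clone(&dispatcher))
///     .set_handler(|msg: u64| {
///         // process a single message here
///     })
///     .run(); // blocks until every worker thread is pinned and ready
/// ```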
pub struct DedicatedWorkerThreadSpawner<T, F>
where
    F: Fn(T) + Send + Sync + 'static,
{
    worker_cores: Option<Vec<CoreId>>,
    dispatcher: Option<Arc<ChannelDispatcher<T>>>,
    handler: Option<F>,
    batch_size: usize,
}

/// Handle for managing a group of dedicated worker threads.
/// Provides methods for graceful shutdown and statistics access.
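///
/// Lifecycle sketch; `handle` is assumed to come from
/// `DedicatedWorkerThreadSpawner::run`:
///
/// ```ignore
/// // Optionally block until every queued message has been handled.
/// handle.wait_for_completion();
///
/// // Stop the workers and read back the final counters.
/// let stats = handle.shutdown(None);
/// ```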
pub struct DedicatedWorkerHandle<T>
where
    T: Send + 'static,
{
    handles: Vec<JoinHandle<()>>,
    dispatcher: Arc<ChannelDispatcher<T>>,
    shutdown_signal: Arc<AtomicBool>,
}

/// Construction of a spawner with the default no-op handler.
impl<T: Send + 'static> DedicatedWorkerThreadSpawner<T, fn(T)> {
    /// Creates a new spawner with a no-op handler function.
    pub fn new() -> Self {
        Self {
            worker_cores: None,
            dispatcher: None,
            handler: Some(|_t: T| {}),
            batch_size: 1,
        }
    }
}

impl<T: Send + 'static> Default for DedicatedWorkerThreadSpawner<T, fn(T)> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Send + 'static, F> DedicatedWorkerThreadSpawner<T, F>
where
    F: Fn(T) + Send + Sync + Clone + 'static,
{
    /// Sets the CPU cores that worker threads will be pinned to.
    pub fn set_cores(mut self, cores: Vec<CoreId>) -> Self {
        self.worker_cores = Some(cores);
        self
    }

    /// Sets the number of messages processed per batch (clamped to a minimum of 1).
    pub fn set_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size.max(1);
        self
    }

    /// Sets the single dispatcher that all worker threads will process subscriptions from.
    pub fn set_dispatcher(mut self, dispatcher: Arc<ChannelDispatcher<T>>) -> Self {
        self.dispatcher = Some(dispatcher);
        self
    }

    /// Sets the handler function that will process all subscriptions. Changes the function type parameter.
    pub fn set_handler<G>(self, handler: G) -> DedicatedWorkerThreadSpawner<T, G>
    where
        G: Fn(T) + Send + Sync + Clone + 'static,
    {
        DedicatedWorkerThreadSpawner {
            worker_cores: self.worker_cores,
            dispatcher: self.dispatcher,
            handler: Some(handler),
            batch_size: self.batch_size,
        }
    }

    /// Spawns worker threads on configured cores. Returns a handle for managing the worker group.
    /// Uses a barrier to ensure all threads are ready before returning.
    pub fn run(self) -> DedicatedWorkerHandle<T>
    where
        F: 'static,
    {
        let worker_cores = self
            .worker_cores
            .expect("Cores must be set via set_cores()");
        let dispatcher = self
            .dispatcher
            .expect("Dispatcher must be set via set_dispatcher()");
        let handler = Arc::new(
            self.handler
                .expect("Handler function must be set via set_handler()"),
        );

        let batch_size = self.batch_size;
        let receivers = Arc::new(dispatcher.receivers());
        let num_threads = worker_cores.len();
        let shutdown_signal = Arc::new(AtomicBool::new(false));

        // Barrier to ensure all threads are spawned before returning
        let startup_barrier = Arc::new(Barrier::new(num_threads + 1)); // +1 for main thread

        let mut handles = Vec::with_capacity(num_threads);
        for core in worker_cores {
            let receivers_ref = Arc::clone(&receivers);
            let handler_ref = Arc::clone(&handler);
            let dispatcher_ref = Arc::clone(&dispatcher);
            let barrier_ref = Arc::clone(&startup_barrier);
            let shutdown_ref = Arc::clone(&shutdown_signal);

            let handle = thread::spawn(move || {
                if let Err(e) = pin_thread_to_core(core.raw()) {
                    eprintln!("Failed to pin thread to core {core}: {e}");
                }

                // Signal that this thread is ready
                barrier_ref.wait();

                Self::run_worker_loop(
                    &receivers_ref,
                    &handler_ref,
                    &dispatcher_ref,
                    batch_size,
                    &shutdown_ref,
                );
            });

            handles.push(handle);
        }

        // Wait for all threads to be ready
        startup_barrier.wait();

        DedicatedWorkerHandle {
            handles,
            dispatcher,
            shutdown_signal,
        }
    }

    /// Process channel messages in batches.
    fn process_batch(batch: Vec<T>, handler: &F, dispatcher: &Arc<ChannelDispatcher<T>>) {
        if batch.is_empty() {
            return;
        }

        let batch_size = batch.len() as u64;

        dispatcher
            .stats()
            .actively_processing
            .fetch_add(batch_size, Ordering::Relaxed);

        for data in batch {
            handler(data);
        }

        dispatcher
            .stats()
            .processed
            .fetch_add(batch_size, Ordering::Relaxed);
        dispatcher
            .stats()
            .actively_processing
            .fetch_sub(batch_size, Ordering::Relaxed);
    }

    /// Main worker loop that uses crossbeam `Select` to efficiently wait on the dispatcher's channels.
    /// Drains messages into batches, runs the shared handler, and updates processing statistics.
    fn run_worker_loop(
        receivers: &[Arc<Receiver<T>>],
        handler: &F,
        dispatcher: &Arc<ChannelDispatcher<T>>,
        batch_size: usize,
        shutdown_signal: &Arc<AtomicBool>,
    ) {
        let mut select = Select::new();
        for receiver in receivers {
            select.recv(receiver);
        }

        loop {
            if shutdown_signal.load(Ordering::Relaxed) {
                break;
            }

            let oper = select.select();
            let index = oper.index();
            let receiver = &receivers[index];

            let mut batch = Vec::with_capacity(batch_size);
            let mut recv_error: Option<TryRecvError> = None;

            match oper.recv(receiver) {
                Ok(msg) => {
                    batch.push(msg);
                }
                Err(_) => {
                    // Channel is disconnected, exit the loop
                    break;
                }
            }

            // One message is already in the batch; opportunistically drain up to
            // batch_size - 1 more without blocking.
            for _ in 1..batch_size {
                match receiver.try_recv() {
                    Ok(msg) => {
                        batch.push(msg);
                    }
                    Err(e) => {
                        recv_error = Some(e);
                        break;
                    }
                }
            }

            if !batch.is_empty() {
                Self::process_batch(batch, handler, dispatcher);
            }

            if let Some(err) = recv_error {
                match err {
                    TryRecvError::Empty => {
                        continue; // Channel is empty, go back to select
                    }
                    TryRecvError::Disconnected => {
                        break; // Channel is disconnected, exit the loop
                    }
                }
            }
        }
    }
}

impl<T> DedicatedWorkerHandle<T>
where
    T: Send + 'static,
{
    /// Blocks until all queues are empty and no messages are actively processing.
    pub fn wait_for_completion(&self) {
        let receivers = self.dispatcher.receivers();

        loop {
            let queues_empty = receivers.iter().all(|r| r.is_empty());
            let active_handlers = self.dispatcher.stats().get_actively_processing();

            if queues_empty && active_handlers == 0 {
                break;
            }

            // Small sleep to avoid busy waiting
            sleep(Duration::from_millis(10));
        }
    }

    /// Gracefully shuts down all worker threads and returns the final statistics snapshot.
    /// If `flush_dir` is provided, any messages still queued in the channels are written to disk.
    /// Otherwise, this blocks until every queued message has been processed, which can look like
    /// a stall when the backlog is large.
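    ///
    /// Sketch of the two modes (the flush path shown here is illustrative):
    ///
    /// ```ignore
    /// // Drain mode: block until every queued message has been handled.
    /// let stats = handle.shutdown(None);
    /// ```
    ///
    /// or, to stop quickly and dump unprocessed messages to
    /// `<flush_dir>/<dispatcher name>.json`:
    ///
    /// ```ignore
    /// let dir = PathBuf::from("/tmp/iris_flush");
    /// let stats = handle.shutdown(Some(&dir));
    /// ```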
    pub fn shutdown(mut self, flush_dir: Option<&PathBuf>) -> SubscriptionStats
    where
        T: Serialize,
    {
        if let Some(dir) = flush_dir {
            self.flush_shutdown(dir);
        } else {
            self.complete_shutdown();
        }

        self.dispatcher.stats().snapshot()
    }

    fn complete_shutdown(&mut self) {
        self.wait_for_completion();
        self.shutdown_signal.store(true, Ordering::SeqCst);

        self.dispatcher.close_channels();

        for (i, handle) in self.handles.drain(..).enumerate() {
            if let Err(e) = handle.join() {
                eprintln!("Thread {i} error: {e:?}");
            }
        }
    }

    fn flush_shutdown(&mut self, flush_dir: &Path)
    where
        T: Serialize,
    {
        self.shutdown_signal.store(true, Ordering::Relaxed);

        sleep(Duration::from_millis(50));
        let mut flushed_messages = Vec::new();

        for receiver in self.dispatcher.receivers().iter() {
            while let Ok(message) = receiver.try_recv() {
                flushed_messages.push(message);
            }
        }

        let message_count = flushed_messages.len() as u64;
        if message_count == 0 {
            return;
        }

        let file_path = flush_dir.join(format!("{}.json", self.dispatcher.name()));

        if flush_messages(&flushed_messages, &file_path).is_ok() {
            println!(
                "Successfully flushed {message_count} messages to: {}",
                file_path.display()
            );
            self.dispatcher
                .stats()
                .flushed
                .fetch_add(message_count, Ordering::Relaxed);
        } else {
            eprintln!("Error occurred when flushing, dropped {message_count} messages");
            self.dispatcher
                .stats()
                .dropped
                .fetch_add(message_count, Ordering::Relaxed);
        }
    }
}

/// Writes messages to disk as formatted JSON.
fn flush_messages<T: Serialize>(messages: &[T], path: &Path) -> Result<()> {
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }

    let mut file = File::create(path)?;
    let json_str = serde_json::to_string_pretty(messages).map_err(Error::other)?;
    writeln!(file, "{json_str}")?;
    Ok(())
}