iris_core/multicore/
dedicated_worker.rs1use super::{pin_thread_to_core, ChannelDispatcher, SubscriptionStats};
2use crate::CoreId;
3use crossbeam::channel::{Receiver, Select, TryRecvError};
4use serde::Serialize;
5use std::fs::File;
6use std::io::{Error, Result, Write};
7use std::path::{Path, PathBuf};
8use std::sync::{
9 atomic::{AtomicBool, Ordering},
10 Arc, Barrier,
11};
12use std::thread::{self, sleep, JoinHandle};
13use std::time::Duration;
14
15pub struct DedicatedWorkerThreadSpawner<T, F>
18where
19 F: Fn(T) + Send + Sync + 'static,
20{
21 worker_cores: Option<Vec<CoreId>>,
22 dispatcher: Option<Arc<ChannelDispatcher<T>>>,
23 handler: Option<F>,
24 batch_size: usize,
25}
26
27pub struct DedicatedWorkerHandle<T>
30where
31 T: Send + 'static,
32{
33 handles: Vec<JoinHandle<()>>,
34 dispatcher: Arc<ChannelDispatcher<T>>,
35 shutdown_signal: Arc<AtomicBool>,
36}
37
38impl<T: Send + 'static> DedicatedWorkerThreadSpawner<T, fn(T)> {
40 pub fn new() -> Self {
42 Self {
43 worker_cores: None,
44 dispatcher: None,
45 handler: Some(|_t: T| {}),
46 batch_size: 1,
47 }
48 }
49}
50
51impl<T: Send + 'static> Default for DedicatedWorkerThreadSpawner<T, fn(T)> {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl<T: Send + 'static, F> DedicatedWorkerThreadSpawner<T, F>
58where
59 F: Fn(T) + Send + Sync + Clone + 'static,
60{
61 pub fn set_cores(mut self, cores: Vec<CoreId>) -> Self {
63 self.worker_cores = Some(cores);
64 self
65 }
66
67 pub fn set_batch_size(mut self, batch_size: usize) -> Self {
69 self.batch_size = batch_size.max(1);
70 self
71 }
72
73 pub fn set_dispatcher(mut self, dispatcher: Arc<ChannelDispatcher<T>>) -> Self {
75 self.dispatcher = Some(dispatcher);
76 self
77 }
78
79 pub fn set_handler<G>(self, handler: G) -> DedicatedWorkerThreadSpawner<T, G>
81 where
82 G: Fn(T) + Send + Sync + Clone + 'static,
83 {
84 DedicatedWorkerThreadSpawner {
85 worker_cores: self.worker_cores,
86 dispatcher: self.dispatcher,
87 handler: Some(handler),
88 batch_size: self.batch_size,
89 }
90 }
91
92 pub fn run(self) -> DedicatedWorkerHandle<T>
95 where
96 F: 'static,
97 {
98 let worker_cores = self
99 .worker_cores
100 .expect("Cores must be set via set_cores()");
101 let dispatcher = self
102 .dispatcher
103 .expect("Dispatcher must be set via set_dispatcher()");
104 let handler = Arc::new(
105 self.handler
106 .expect("Handler function must be set via set_handler()"),
107 );
108
109 let batch_size = self.batch_size;
110 let receivers = Arc::new(dispatcher.receivers());
111 let num_threads = worker_cores.len();
112 let shutdown_signal = Arc::new(AtomicBool::new(false));
113
114 let startup_barrier = Arc::new(Barrier::new(num_threads + 1)); let mut handles = Vec::with_capacity(num_threads);
118 for core in worker_cores {
119 let receivers_ref = Arc::clone(&receivers);
120 let handler_ref = Arc::clone(&handler);
121 let dispatcher_ref = Arc::clone(&dispatcher);
122 let barrier_ref = Arc::clone(&startup_barrier);
123 let shutdown_ref = Arc::clone(&shutdown_signal);
124
125 let handle = thread::spawn(move || {
126 if let Err(e) = pin_thread_to_core(core.raw()) {
127 eprintln!("Failed to pin thread to core {core}: {e}");
128 }
129
130 barrier_ref.wait();
132
133 Self::run_worker_loop(
134 &receivers_ref,
135 &handler_ref,
136 &dispatcher_ref,
137 batch_size,
138 &shutdown_ref,
139 );
140 });
141
142 handles.push(handle);
143 }
144
145 startup_barrier.wait();
147
148 DedicatedWorkerHandle {
149 handles,
150 dispatcher,
151 shutdown_signal,
152 }
153 }
154
155 fn process_batch(batch: Vec<T>, handler: &F, dispatcher: &Arc<ChannelDispatcher<T>>) {
157 if batch.is_empty() {
158 return;
159 }
160
161 let batch_size = batch.len() as u64;
162
163 dispatcher
164 .stats()
165 .actively_processing
166 .fetch_add(batch_size, Ordering::Relaxed);
167
168 for data in batch {
169 handler(data);
170 }
171
172 dispatcher
173 .stats()
174 .processed
175 .fetch_add(batch_size, Ordering::Relaxed);
176 dispatcher
177 .stats()
178 .actively_processing
179 .fetch_sub(batch_size, Ordering::Relaxed);
180 }
181
182 fn run_worker_loop(
185 receivers: &[Arc<Receiver<T>>],
186 handler: &F,
187 dispatcher: &Arc<ChannelDispatcher<T>>,
188 batch_size: usize,
189 shutdown_signal: &Arc<AtomicBool>,
190 ) {
191 let mut select = Select::new();
192 for receiver in receivers {
193 select.recv(receiver);
194 }
195
196 loop {
197 if shutdown_signal.load(Ordering::Relaxed) {
198 break;
199 }
200
201 let oper = select.select();
202 let index = oper.index();
203 let receiver = &receivers[index];
204
205 let mut batch = Vec::with_capacity(batch_size);
206 let mut recv_error: Option<TryRecvError> = None;
207
208 match oper.recv(receiver) {
209 Ok(msg) => {
210 batch.push(msg);
211 }
212 Err(_) => {
213 break;
215 }
216 }
217
218 for _ in 0..batch_size {
219 match receiver.try_recv() {
220 Ok(msg) => {
221 batch.push(msg);
222 }
223 Err(e) => {
224 recv_error = Some(e);
225 break;
226 }
227 }
228 }
229
230 if !batch.is_empty() {
231 Self::process_batch(batch, handler, dispatcher);
232 }
233
234 if let Some(err) = recv_error {
235 match err {
236 TryRecvError::Empty => {
237 continue; }
239 TryRecvError::Disconnected => {
240 break; }
242 }
243 }
244 }
245 }
246}
247
248impl<T> DedicatedWorkerHandle<T>
249where
250 T: Send + 'static,
251{
252 pub fn wait_for_completion(&self) {
254 let receivers = self.dispatcher.receivers();
255
256 loop {
257 let queues_empty = receivers.iter().all(|r| r.is_empty());
258 let active_handlers = self.dispatcher.stats().get_actively_processing();
259
260 if queues_empty && active_handlers == 0 {
261 break;
262 }
263
264 sleep(Duration::from_millis(10));
266 }
267 }
268
269 pub fn shutdown(mut self, flush_dir: Option<&PathBuf>) -> SubscriptionStats
276 where
277 T: Serialize,
278 {
279 if let Some(dir) = flush_dir {
280 self.flush_shutdown(dir);
281 } else {
282 self.complete_shutdown();
283 }
284
285 self.dispatcher.stats().snapshot()
286 }
287
288 fn complete_shutdown(&mut self) {
289 self.wait_for_completion();
290 self.shutdown_signal.store(true, Ordering::SeqCst);
291
292 self.dispatcher.close_channels();
293
294 for (i, handle) in self.handles.drain(..).enumerate() {
295 if let Err(e) = handle.join() {
296 eprintln!("Thread {i} error: {e:?}");
297 }
298 }
299 }
300
301 fn flush_shutdown(&mut self, flush_dir: &Path)
302 where
303 T: Serialize,
304 {
305 self.shutdown_signal.store(true, Ordering::Relaxed);
306
307 sleep(Duration::from_millis(50));
308 let mut flushed_messages = Vec::new();
309
310 for receiver in self.dispatcher.receivers().iter() {
311 while let Ok(message) = receiver.try_recv() {
312 flushed_messages.push(message);
313 }
314 }
315
316 let message_count = flushed_messages.len() as u64;
317 if message_count == 0 {
318 return;
319 }
320
321 let file_path = flush_dir.join(format!("{}.json", self.dispatcher.name()));
322
323 if flush_messages(&flushed_messages, &file_path).is_ok() {
324 println!(
325 "Successfully flushed {message_count} messages to: {}",
326 file_path.display()
327 );
328 self.dispatcher
329 .stats()
330 .flushed
331 .fetch_add(message_count, Ordering::Relaxed);
332 } else {
333 eprintln!("Error occurred when flushing, dropped {message_count} messages");
334 self.dispatcher
335 .stats()
336 .dropped
337 .fetch_add(message_count, Ordering::Relaxed);
338 }
339 }
340}
341
342fn flush_messages<T: Serialize>(messages: &[T], path: &PathBuf) -> Result<()> {
344 if let Some(parent) = path.parent() {
345 std::fs::create_dir_all(parent)?;
346 }
347
348 let mut file = File::create(path)?;
349 let json_str = serde_json::to_string_pretty(messages).map_err(Error::other)?;
350 writeln!(file, "{json_str}")?;
351 Ok(())
352}